In [70]:
import pandas as pd
import numpy as np
from datetime import datetime
import json, io, base64, re, os, requests
import plotly.graph_objects as go
from IPython.display import display, HTML, clear_output
import plotly.express as px
import ipywidgets as widgets
from xml.etree import ElementTree
# Initialize settings
import _settings as settings

# Global variables from settings (the project-local `_settings` module)
# NOTE(review): spec_translate_list is not used in the cells visible here; presumably later cells use it — confirm.
spec_translate_list = settings.SPEC_TRANSLATE_LIST
# Options for ProteinPlotter's 'Color Scheme:' dropdown, which selects 'HSV' by default.
plotly_colors = settings.plotly_colors

In [71]:
class DataTransformation:
    """Upload, validate and annotate a merged peptide data file.

    Attributes:
        merged_df: validated dataframe from the last successful upload (None until then).
        proteins_dic: accession -> {"name": ..., "species": ...}, filled from the file's
            'protein_name'/'protein_species' columns or fetched from UniProt. Entries are
            kept across uploads.
        output_area: Output widget for upload messages (created by setup_data_loading_ui).
        status_area: Output widget for status display (created by setup_data_loading_ui).
        merged_uploader: FileUpload widget (created by setup_data_loading_ui).
    """

    # Seconds to wait for a UniProt response; without a timeout requests can block forever.
    UNIPROT_TIMEOUT = 30

    # CSS shared by the UniProt fetch progress display and its final summary.
    _FETCH_STYLE_HTML = """
            <style>
                .fetch-status { font-family: monospace; margin: 10px 0; padding: 10px; }
                .fetch-progress { margin: 5px 0; padding: 5px; }
                .success { color: #28a745; }
                .warning { color: #ffc107; }
                .error { color: #dc3545; }
                .summary { margin-top: 10px; padding: 10px;}
            </style>
    """
    _FETCH_PROGRESS_HTML = _FETCH_STYLE_HTML + """
            <div class="fetch-status">
                <div id="progress-updates"></div>
            </div>
    """
    # Namespace used by UniProt XML entries.
    _UNIPROT_XML_NS = {'up': 'http://uniprot.org/uniprot'}
    # XML paths tried in order for the protein name (short names preferred).
    _UNIPROT_XML_NAME_PATHS = (
        './/up:recommendedName/up:shortName',
        './/up:alternativeName/up:shortName',
        './/up:recommendedName/up:fullName',
        './/up:submittedName/up:fullName',
    )

    def __init__(self):
        self.merged_df = None
        self.proteins_dic = {}
        self.output_area = None
        self.merged_uploader = None
        # Also created by setup_data_loading_ui(); set here so the attribute always exists.
        self.status_area = None

    def create_download_link(self, file_path, label):
        """Create a download link for a file.

        The file is read from disk and embedded in the link as a base64 data URI.

        Args:
            file_path: path of the file to offer for download.
            label: link text.

        Returns:
            widgets.HTML containing the link, or a red "not found" message when
            ``file_path`` does not exist.
        """
        if os.path.exists(file_path):
            # Read file content and encode it as base64
            with open(file_path, 'rb') as f:
                content = f.read()
            b64_content = base64.b64encode(content).decode('utf-8')

            # Generate the download link HTML
            return widgets.HTML(f"""
                <a download="{os.path.basename(file_path)}" 
                   href="data:application/octet-stream;base64,{b64_content}" 
                   style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                    {label}
                </a>
            """)
        else:
            # Show an error message if the file does not exist
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{file_path}" not found!
                </span>
            """)

    def setup_data_loading_ui(self):
        """Initialize and display the data loading UI.

        Creates the FileUpload widget (csv/txt/tsv/xlsx), a download link for
        "example_merged_dataframe.csv", and the output/status areas. It displays
        them and registers _on_merged_upload_change as the uploader's value observer.
        """
        # Create file upload widget
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.output_area = widgets.Output()

        # Create upload box with example link
        merged_box = widgets.HBox([
            self.merged_uploader,
            self.create_download_link("example_merged_dataframe.csv", "Example")
        ], layout=widgets.Layout(align_items='center'))

        # Create left column with upload widgets
        upload_widgets = widgets.VBox([
            widgets.HTML("<h4>Upload Data File:</h4>"),
            merged_box,
            self.output_area
        ], layout=widgets.Layout(
            width='300px',
            margin='0 20px 0 0'
        ))

        # Create container for status display
        self.status_area = widgets.Output(
            layout=widgets.Layout(
                width='300px',
                margin='0 0 0 20px'
            )
        )

        display(upload_widgets,
                self.status_area)

        # Register observer
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')

    def _validate_and_clean_data(self, df):
        """
        Validate and clean the uploaded data, preserving numeric data even if stored as strings.

        Checks, in order:
          1. at least one column whose name starts with 'Avg_' exists;
          2. the 'Master Protein Accessions' and 'unique ID' columns exist;
          3. rows with a missing/blank value in either of those text columns are
             dropped (reported as a warning);
          4. every 'Avg_' column is converted with pd.to_numeric(errors='coerce') and
             the resulting NaN count is reported as a warning;
          5. ',' or ':' in 'Master Protein Accessions' is an error. The same check runs
             on 'Positions in Proteins' only when the file has that column.

        Args:
            df: the raw uploaded dataframe (not modified).

        Returns:
            Tuple of (cleaned_df, warnings, errors). cleaned_df is None when check 1,
            check 2 or a numeric conversion fails. After check 5 it is still returned,
            so callers must inspect ``errors``.
        """
        warnings = []
        errors = []

        # Non-string column labels (e.g. ints from Excel) cannot be group columns.
        avg_columns = [col for col in df.columns
                       if isinstance(col, str) and col.startswith('Avg_')]
        if not avg_columns:
            errors.append("No columns starting with 'Avg_' found in the data")
            return None, warnings, errors

        text_columns = ['Master Protein Accessions', 'unique ID']
        missing = [col for col in text_columns if col not in df.columns]
        if missing:
            errors.append(f"Missing required columns: {', '.join(missing)}")
            return None, warnings, errors

        cleaned_df = df.copy()

        # Text columns first: a row without an accession or ID cannot be used at all.
        for column in text_columns:
            blank_mask = cleaned_df[column].isna() | (cleaned_df[column].astype(str).str.strip() == '')
            blank_count = blank_mask.sum()
            if blank_count > 0:
                warnings.append(f"Dropping {blank_count} rows with blank values in {column} column")
                # .copy() so the numeric assignments below do not write into a slice.
                cleaned_df = cleaned_df.loc[~blank_mask].copy()

        # Numeric columns: values such as '1.5' are converted; unparseable values become NaN.
        for column in avg_columns:
            try:
                cleaned_df[column] = pd.to_numeric(cleaned_df[column], errors='coerce')
            except Exception as e:
                errors.append(f"Error converting {column} to numeric: {str(e)}")
                return None, warnings, errors
            blank_count = cleaned_df[column].isna().sum()
            if blank_count > 0:
                warnings.append(f"Found {blank_count} invalid/blank numeric values in {column} column")

        # Check for invalid characters in non-blank rows
        if len(cleaned_df) > 0:
            for column in ('Positions in Proteins', 'Master Protein Accessions'):
                # 'Positions in Proteins' is not a required column; only check it when present.
                if column not in cleaned_df.columns:
                    continue
                has_forbidden = cleaned_df[column].astype(str).str.contains('[,:]', regex=True)
                if has_forbidden.any():
                    errors.append(
                        f"Found invalid characters (',' or ':') in {column} column. "
                        "Please update the file and upload again."
                    )

        return cleaned_df, warnings, errors

    @staticmethod
    def _has_text(value):
        """Return True when ``value`` is neither missing (NaN/None) nor the empty string."""
        return bool(pd.notna(value)) and value != ''

    @staticmethod
    def _all_protein_info_present(df):
        """Return True when every row has non-missing, non-empty protein_name and protein_species."""
        return all(
            df[col].notna().all() and (df[col] != '').all()
            for col in ('protein_name', 'protein_species')
        )

    @staticmethod
    def _fetch_summary_html(heading, found, not_found, multiple):
        """Return the HTML list of UniProt fetch counters under ``heading``."""
        return f"""
            <div class="summary">
                <h4>{heading}</h4>
                <ul>
                    <li>UniProt matches found: {found}</li>
                    <li>UniProt matches not found: {not_found}</li>
                    <li>Multiple entry proteins: {multiple}</li>
                </ul>
            </div>
        """

    def _fetch_and_store(self, protein_id):
        """Fetch one accession from UniProt and store the result (or a placeholder) in proteins_dic.

        Returns:
            (status_html, found), where found is True when UniProt returned both a
            name and a species.
        """
        try:
            name, species = self.fetch_uniprot_info(protein_id)
        except Exception as e:
            self.proteins_dic[protein_id] = {"name": protein_id, "species": "Unknown"}
            return f'<div class="fetch-progress error">Error processing {protein_id}: {str(e)}</div>', False

        if name and species:
            self.proteins_dic[protein_id] = {"name": name, "species": species}
            return (f'<div class="fetch-progress success">Found UniProt data for: '
                    f'{protein_id} → {name} ({species})</div>'), True

        self.proteins_dic[protein_id] = {"name": protein_id, "species": "Unknown"}
        return f'<div class="fetch-progress error">No UniProt data found for: {protein_id}</div>', False

    def _process_protein_info(self, df):
        """
        Process protein information from the dataframe and store in proteins_dic.
        Only shows progress when fetching from UniProt.

        For each unique 'Master Protein Accessions' value:
          * an accession containing ';' (several proteins) is stored with itself as
            the name and species "Multiple";
          * otherwise the group's first 'protein_name'/'protein_species' is used when
            both are present and non-empty;
          * otherwise the entry is fetched with fetch_uniprot_info(). When that fails,
            the accession is used as the name with species "Unknown".

        Args:
            df: validated dataframe with a 'Master Protein Accessions' column. The
                'protein_name'/'protein_species' columns are optional.

        Returns:
            The number of entries now in self.proteins_dic, which may include entries
            from earlier calls.
        """
        key = 'Master Protein Accessions'
        info_columns = ['protein_name', 'protein_species']

        # Fast path: every row already has protein info, so no UniProt calls and no progress widget.
        if all(col in df.columns for col in info_columns) and self._all_protein_info_present(df):
            protein_info = df.groupby(key)[info_columns].first().reset_index()
            for protein_id, name, species in protein_info.itertuples(index=False, name=None):
                self.proteins_dic[protein_id] = {"name": name, "species": species}
            return len(self.proteins_dic)

        # If we need to fetch data, show progress
        progress_output = widgets.Output()
        display(progress_output)

        uniprot_found = 0
        uniprot_not_found = 0
        multiple_entries = 0

        # reindex adds any absent info column as NaN, which the per-protein check treats as missing.
        protein_info = (
            df.reindex(columns=[key, *info_columns])
            .groupby(key)[info_columns]
            .first()
            .reset_index()
        )

        for protein_id, name, species in protein_info.itertuples(index=False, name=None):
            # Entries listing several accessions are not looked up.
            if ';' in str(protein_id):
                multiple_entries += 1
                self.proteins_dic[protein_id] = {"name": protein_id, "species": "Multiple"}
                continue

            # Use existing data if available and not empty
            if self._has_text(name) and self._has_text(species):
                self.proteins_dic[protein_id] = {"name": name, "species": species}
                continue

            status, found = self._fetch_and_store(protein_id)
            if found:
                uniprot_found += 1
            else:
                uniprot_not_found += 1

            # Update progress display only for fetched proteins
            with progress_output:
                progress_output.clear_output(wait=True)
                display(HTML(
                    self._FETCH_PROGRESS_HTML + status
                    + self._fetch_summary_html('Fetch Progress:', uniprot_found,
                                               uniprot_not_found, multiple_entries)
                ))

        # Only show final summary if we had to fetch any data
        if uniprot_found + uniprot_not_found > 0:
            with progress_output:
                progress_output.clear_output(wait=True)
                display(HTML(
                    self._FETCH_STYLE_HTML
                    + '<div class="fetch-status">'
                    + '<h4 style="color:green;"><b>UniProt Fetch Complete!</b></h4>'
                    + self._fetch_summary_html('Final Summary:', uniprot_found,
                                               uniprot_not_found, multiple_entries)
                    + '</div>'
                ))

        return len(self.proteins_dic)

    def _read_uploaded_file(self, file_data):
        """Parse an uploaded file into a DataFrame based on its extension.

        Supports .csv (comma-separated), .txt/.tsv (tab-separated) and .xlsx.

        Returns:
            The DataFrame, or None after displaying an error for an unsupported
            extension or a parse failure.
        """
        content = bytes(file_data.content)
        extension = file_data.name.split('.')[-1].lower()
        file_stream = io.BytesIO(content)

        try:
            if extension == 'csv':
                return pd.read_csv(file_stream)
            if extension in ('txt', 'tsv'):
                return pd.read_csv(file_stream, delimiter='\t')
            if extension == 'xlsx':
                return pd.read_excel(file_stream)
        except Exception as e:
            display(HTML(f'<b style="color:red;">Error reading file: {str(e)}</b>'))
            return None

        display(HTML('<b style="color:red;">Error: Unsupported file format</b>'))
        return None

    def _add_missing_protein_columns(self, df):
        """Add empty protein_name/protein_species columns that ``df`` lacks (in place) and notify the user."""
        missing_columns = [col for col in ('protein_name', 'protein_species') if col not in df.columns]
        if not missing_columns:
            return
        for col in missing_columns:
            df[col] = ''

        items = ''.join(f'<li>{col}</li>' for col in missing_columns)
        display(HTML(f"""
        <div style="padding: 10px; margin: 10px 0;">
            <p style="color: #17a2b8; margin: 0;">
                <b>Notice:</b> The following columns are missing from your data:
            </p>
            <ul style="color: #17a2b8; margin: 5px 0;">
                {items}
            </ul>
            <p style="color: #17a2b8; margin: 0;">
                UniProt will be searched to automatically fill in this information. <br>
                Alternatively, you can upload a standardized file from the data transformation module with the protein information.
            </p>
        </div>
        """))

    def _load_merged_data(self, file_data):
        """
        Load and validate merged data file.

        Args:
            file_data: one entry of the FileUpload widget's value; its ``name`` and
                ``content`` attributes are read.

        Returns:
            Tuple of (dataframe, status): (cleaned_df, 'yes') on success, otherwise
            (None, 'no'). Errors are displayed as HTML rather than raised.
        """
        try:
            df = self._read_uploaded_file(file_data)
            if df is None:
                return None, 'no'

            self._add_missing_protein_columns(df)

            cleaned_df, warnings, errors = self._validate_and_clean_data(df)
            # Validation warnings (dropped/blank rows) are intentionally not displayed.
            if errors:
                display(HTML("<br>".join(f'<b style="color:red;">Error: {e}</b>' for e in errors)))
                return None, 'no'

            if cleaned_df is None or len(cleaned_df) == 0:
                display(HTML('<b style="color:red;">Error: No valid data rows remaining after cleaning</b>'))
                return None, 'no'

            # Fills self.proteins_dic (from the file or UniProt); the returned count is not needed here.
            self._process_protein_info(cleaned_df)
            return cleaned_df, 'yes'

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error processing file: {str(e)}</b>'))
            return None, 'no'

    def _on_merged_upload_change(self, change):
        """Handle merged data file upload.

        Loads the first uploaded file and stores it in ``self.merged_df`` only when it
        passes validation. A failed upload leaves previously loaded data untouched.
        Messages go to ``self.output_area``.
        """
        if change['type'] == 'change' and change['name'] == 'value':
            with self.output_area:
                self.output_area.clear_output()
                if change['new'] and len(change['new']) > 0:
                    file_data = change['new'][0]
                    df, status = self._load_merged_data(file_data)
                    if status == 'yes' and df is not None:
                        self.merged_df = df  # Only set merged_df if validation passed
                        display(HTML(
                            f'<b style="color:green;">Data imported successfully with '
                            f'{df.shape[0]} rows and {df.shape[1]} columns.</b>'
                        ))

    @staticmethod
    def _parse_uniprot_json(data):
        """Extract (protein_name, species) from a UniProt REST JSON entry.

        Returns None when the entry lacks 'proteinDescription' or 'organism', so the
        caller can fall back to the XML endpoint. Either tuple element may be None.
        """
        if not isinstance(data, dict) or 'proteinDescription' not in data or 'organism' not in data:
            return None
        description = data['proteinDescription']
        organism = data['organism']

        recommended = description.get('recommendedName') or {}
        # `or [{}]` also covers a present-but-empty list.
        first_alternative = (description.get('alternativeNames') or [{}])[0]
        if recommended.get('shortNames'):
            protein_name = recommended['shortNames'][0].get('value')
        elif first_alternative.get('shortNames'):
            protein_name = first_alternative['shortNames'][0].get('value')
        else:
            protein_name = (recommended.get('fullName') or {}).get('value')
        if not protein_name:
            # Unreviewed (TrEMBL) entries carry submissionNames instead of a recommendedName.
            first_submission = (description.get('submissionNames') or [{}])[0]
            protein_name = (first_submission.get('fullName') or {}).get('value')

        # UniProt REST JSON exposes the common name as organism.commonName; a 'names'
        # list is kept as a secondary source, then the scientific name is the fallback.
        species = organism.get('commonName')
        if not species:
            species = next(
                (entry.get('value') for entry in organism.get('names', [])
                 if entry.get('type') == 'common'),
                None,
            )
        if not species:
            species = organism.get('scientificName')
        return protein_name, species

    @classmethod
    def _parse_uniprot_xml(cls, content):
        """Extract (protein_name, species) from a UniProt XML entry; (None, None) when no name is found."""
        root = ElementTree.fromstring(content)
        ns = cls._UNIPROT_XML_NS

        protein_name = None
        for path in cls._UNIPROT_XML_NAME_PATHS:
            element = root.find(path, ns)
            # Compare with None explicitly: an Element without children is falsy, so
            # chaining find() results with `or` skips perfectly good leaf matches.
            if element is not None and element.text:
                protein_name = element.text
                break

        # Species: common name first, then scientific name.
        common = root.find('.//up:organism/up:name[@type="common"]', ns)
        if common is not None:
            species = common.text
        else:
            scientific = root.find('.//up:organism/up:name[@type="scientific"]', ns)
            species = scientific.text if scientific is not None else "Unknown"

        if not protein_name:
            return None, None
        # Clean up protein name - remove any "precursor" or parenthesised suffix.
        return protein_name.split(' precursor')[0].split(' (')[0], species

    def fetch_uniprot_info(self, protein_id):
        """
        Fetch protein information from UniProt, prioritizing common names.

        The UniProt REST JSON endpoint is tried first. The XML endpoint is used when
        the JSON request fails or lacks the expected sections.

        Args:
            protein_id: a single UniProt accession.

        Returns:
            Tuple of (protein_name, species_name), or (None, None) when the protein is
            not found or an error occurs (the error is printed). The protein name prefers
            a short name over the full name. The species prefers the organism's common
            name over its scientific name.
        """
        try:
            response = requests.get(
                f'https://rest.uniprot.org/uniprotkb/{protein_id}.json',
                timeout=self.UNIPROT_TIMEOUT,
            )
            if response.status_code == 200:
                parsed = self._parse_uniprot_json(response.json())
                if parsed is not None:
                    return parsed

            # Fall back to XML API
            response = requests.get(
                f'https://www.uniprot.org/uniprot/{protein_id}.xml',
                timeout=self.UNIPROT_TIMEOUT,
            )
            if response.status_code != 200:
                return None, None
            return self._parse_uniprot_xml(response.content)

        except Exception as e:
            print(f"Error fetching UniProt data for {protein_id}: {str(e)}")
            return None, None




In [72]:
class PlotState:
    """Record the current plot inputs and whether the action buttons are locked.

    ``current_state`` maps input names to their latest values, plus two
    bookkeeping entries. ``'lastGenerated'`` is a snapshot of the inputs taken by
    generate_completed(). ``'buttonsLocked'`` starts True, is reset to True by every
    update_state() call, and is set to False by generate_completed().
    """

    def __init__(self):
        self.current_state = dict(
            uploadedData=None,
            topProteins=None,
            groupSelection=None,
            xLabel=None,
            yLabel=None,
            colorScheme=None,
            lastGenerated=None,
            buttonsLocked=True,
        )

    def update_state(self, **kwargs):
        """Merge ``kwargs`` into the state, then lock the buttons again."""
        state = self.current_state
        for key, value in kwargs.items():
            state[key] = value
        # Any input change invalidates the last generated output.
        state['buttonsLocked'] = True

    def generate_completed(self):
        """Unlock the buttons and snapshot every non-bookkeeping entry as 'lastGenerated'."""
        state = self.current_state
        state['buttonsLocked'] = False
        snapshot = {}
        for key, value in state.items():
            if key in ('lastGenerated', 'buttonsLocked'):
                continue
            snapshot[key] = value
        state['lastGenerated'] = snapshot

    def get_state(self):
        """Return the live state dictionary (not a copy)."""
        return self.current_state


In [73]:
class ProteinPlotter:
    def __init__(self, data_transformer):
        """Build the plot control widgets and wire their event handlers.

        Args:
            data_transformer: a DataTransformation whose ``merged_uploader`` widget
                already exists (i.e. setup_data_loading_ui() has run). Its value is
                observed so the group list refreshes after each upload.
        """
        self.data_transformer = data_transformer
        self.plot_output = widgets.Output()
        self.info_output = widgets.Output()
        self.export_output = widgets.Output()
        self.proteins_df = None
        self.sum_df = None
        # Stores the current figure (None until set).
        self.current_fig = None

        # Initialize state manager first
        self.state_manager = PlotState()

        # Create widgets for protein plotting (created once; it was previously built twice).
        self.num_proteins_widget = widgets.IntSlider(
            value=10,
            min=1,
            max=50,
            step=1,
            description='Top Number of Proteins:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        self.download_plot_button = widgets.Button(
            description='Download Interactive Plot',
            button_style='info',
            icon='file',
            layout=widgets.Layout(width='200px')
        )

        # Create multi-select widget for groups
        self.group_select = widgets.SelectMultiple(
            options=[],
            description='Groups:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px', height='100px')
        )

        self.plot_button = widgets.Button(
            description='Generate/Update Data',
            button_style='success',
            icon='refresh',
            layout=widgets.Layout(width='200px')
        )

        self.export_button = widgets.Button(
            description='Export Data',
            button_style='info',
            icon='download',
            layout=widgets.Layout(width='200px')
        )

        # Add label customization widgets
        self.xlabel_widget = widgets.Text(
            description='X Label:',
            placeholder='Enter x-axis label',
            layout=widgets.Layout(width='300px')
        )
        self.ylabel_widget = widgets.Text(
            description='Y Label:',
            placeholder='Enter y-axis label',
            layout=widgets.Layout(width='300px')
        )
        self.legend_widget = widgets.Text(
            description='Legend Title',
            placeholder='Enter a custom legend title',
            layout=widgets.Layout(width='300px')
        )
        self.title_widget = widgets.Text(
            description='Plot Title',
            placeholder='Enter a custom plot title',
            layout=widgets.Layout(width='300px')
        )

        self.color_scheme = widgets.Dropdown(
            options=plotly_colors,
            value='HSV',
            description='Color Scheme:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Create layout
        self.widget_box = widgets.VBox([
            widgets.HTML("<h4>Plot Controls:</h4>"),
            self.num_proteins_widget,
            self.group_select,
            widgets.HTML("<h4>Appearance Settings:</h4>"),
            self.xlabel_widget,
            self.ylabel_widget,
            self.legend_widget,
            self.title_widget,
            self.color_scheme,
            widgets.HTML("<h4>Actions:</h4>"),
            self.plot_button,
            self.export_button,
            self.download_plot_button,
            self.info_output,
            self.plot_output,
            self.export_output
        ])

        # Export/download stay disabled until data has been generated.
        self.export_button.disabled = True
        self.download_plot_button.disabled = True

        def _on_input_change(change):
            # Any input edit makes the last generated output stale.
            self.state_manager.update_state()
            self.export_button.disabled = True
            self.download_plot_button.disabled = True

        for control in (self.num_proteins_widget, self.group_select, self.xlabel_widget,
                        self.ylabel_widget, self.legend_widget, self.title_widget,
                        self.color_scheme):
            control.observe(_on_input_change, names='value')

        # Add button click handlers
        self.plot_button.on_click(self._on_plot_button_click)
        self.export_button.on_click(self._on_export_button_click)
        self.download_plot_button.on_click(self._on_download_plot_click)

        # Refresh group options whenever a new file is uploaded.
        self.data_transformer.merged_uploader.observe(self._update_group_options, names='value')

    def _get_avg_columns(self):
        """Get all columns that start with 'Avg_' from the merged dataframe"""
        if self.data_transformer.merged_df is not None:
            return [col for col in self.data_transformer.merged_df.columns if col.startswith('Avg_')]
        return []

    def _update_group_options(self, change):
        """Update group selection options when data changes"""
        if self.data_transformer.merged_df is not None:
            avg_columns = self._get_avg_columns()
            # Remove 'Avg_' prefix for display
            group_options = [col.replace('Avg_', '') for col in avg_columns]
            self.group_select.options = group_options
            # Select all groups by default
            self.group_select.value = group_options

    def process_data(self, selected_groups=None):
        """Process data with optional group selection"""
        if self.data_transformer.merged_df is None or not self.data_transformer.proteins_dic:
            return False

        df = self.data_transformer.merged_df.copy()
        
        # Get Absorbance columns based on selected groups
        if selected_groups:
            Absorbance_cols = [f'Avg_{var}' for var in selected_groups]
        else:
            Absorbance_cols = self._get_avg_columns()
            
        df['Total_Absorbance'] = df[Absorbance_cols].sum(axis=1).astype(int)
        
        # Filter out zero Absorbance entries
        result_df = df[['unique ID', 'Total_Absorbance']]
        result_df = result_df[result_df['Total_Absorbance'] == 0]
        all_zero_list = list(result_df['unique ID'])
        peptides_df = df[~df['unique ID'].isin(all_zero_list)]

        # Process protein positions and create proteins DataFrame
        additional_columns = ['Master Protein Accessions', 'unique ID']
        selected_columns = additional_columns + Absorbance_cols
        
        peptides_df.loc[:, 'Master Protein Accessions'] = peptides_df['Master Protein Accessions']
        
        temp_df = peptides_df.copy()
        temp_df.loc[:, 'Protein_ID'] = temp_df['Master Protein Accessions']
        
        # Create proteins DataFrame with selected columns
        self.proteins_df = temp_df.groupby('Protein_ID').agg(
            {**{col: 'first' for col in ['Master Protein Accessions']},
             **{col: 'sum' for col in Absorbance_cols}}
        ).reset_index()
        
        # Calculate relative Absorbance for selected groups
        for col in Absorbance_cols:
            col_sum = self.proteins_df[col].sum()
            if col_sum > 0:  # Avoid division by zero
                self.proteins_df[f'Rel_{col}'] = (self.proteins_df[col] / col_sum) * 100
            else:
                self.proteins_df[f'Rel_{col}'] = 0
            
        # Create sum DataFrame for selected groups
        self.sum_df = pd.DataFrame({
            'Sample': Absorbance_cols,
            'Total_Sum': [self.proteins_df[col].sum() for col in Absorbance_cols]
        })
        
        # Add protein descriptions
        name_list = []
        for _, row in self.proteins_df.iterrows():
            if ',' in row['Protein_ID']:
                strrow = row['Protein_ID'].split(',')
                named_combo = self._fetch_protein_names('; '.join(strrow))
            else:
                named_combo = self._fetch_protein_names(row['Protein_ID'])
            name_list.append(named_combo)
        
        # Drop the 'Protein_ID' column
        self.proteins_df = self.proteins_df.drop(columns=['Protein_ID'])    
        
        self.proteins_df['Description'] = name_list
        self.proteins_df['Description'] = self.proteins_df['Description'].astype(str).str.replace(r"['\['\]]", "", regex=True)
        
        return True

    def _fetch_protein_names(self, accession_str):
        """
        Fetch protein names from the proteins dictionary.
        Returns a list of protein names, using the full protein name.
        """
        names = []
        for acc in accession_str.split('; '):
            if acc in self.data_transformer.proteins_dic:
                # Use the full protein name instead of splitting it
                name = self.data_transformer.proteins_dic[acc]['name']
                names.append(name)
            else:
                names.append(acc)
        return names
    
    def plot_stacked_bar_scaled(self, pro_list, title, selected_groups):
        """
        Build a stacked bar chart of the selected proteins' absorbance per group.

        Each bar is one selected group; each stacked segment is one protein
        whose ``Description`` is in ``pro_list``. Segment heights are the
        protein's relative absorbance (``Rel_Avg_<group>``, a percentage)
        rescaled by the group's total absorbance from ``self.sum_df``. A
        text-only scatter trace labels each bar position with that total.

        Parameters
        ----------
        pro_list : list of str
            Protein descriptions to plot, in stacking/legend order. The position
            of a description in this list also selects its colour.
        title : str
            Fallback plot title, used only when the title widget is empty.
        selected_groups : list of str
            Group names; ``Rel_Avg_<group>`` / ``Avg_<group>`` columns must exist
            in ``self.proteins_df`` and ``Avg_<group>`` rows in ``self.sum_df``.

        Returns
        -------
        plotly.graph_objects.Figure or None
            The figure, or ``None`` when ``self.proteins_df`` or ``self.sum_df``
            has not been produced yet.

        Side effects: calls ``self.state_manager.generate_completed()`` and
        enables the export and download buttons.
        """
        if self.proteins_df is None or self.sum_df is None:
            return None

        # Relative (percentage) columns per group and their absolute counterparts.
        sample_orders = [f'Rel_Avg_{var}' for var in selected_groups]
        sample_mapping = {f'Rel_Avg_{var}': f'Avg_{var}' for var in selected_groups}
        x_labels = [col.replace('Rel_Avg_', '') for col in sample_orders]

        scaled_df = self.proteins_df.copy()

        # Keep every protein's own relative value for the hover text before the
        # Rel_Avg_ columns are overwritten with scaled values. Looking it up by
        # Description picked the first match, which is wrong whenever two
        # proteins share a description.
        hover_rel_cols = {col: f'__hover_{col}' for col in sample_orders}
        for col, hover_col in hover_rel_cols.items():
            scaled_df[hover_col] = self.proteins_df[col]

        # Look up each group's total once: it scales the segments and labels the bars.
        total_sums = {}
        for col, label in zip(sample_orders, x_labels):
            sample_key = sample_mapping[col]
            total_sum = self.sum_df.loc[self.sum_df['Sample'] == sample_key, 'Total_Sum'].values[0]
            total_sums[label] = total_sum

            # Rel_ columns are already percentages; rescale them so the segments
            # of all proteins add up to the group's total absorbance. Non-positive
            # totals (or an all-zero relative column) are left unscaled.
            rel_total = self.proteins_df[col].sum()
            if total_sum > 0 and rel_total != 0:
                scaled_df[col] = self.proteins_df[col] * total_sum / rel_total

        # Stack in pro_list order; rows whose description is not listed map to NaN and sort last.
        description_order = {desc: i for i, desc in enumerate(pro_list)}
        scaled_df['Order'] = scaled_df['Description'].map(description_order)
        scaled_df = scaled_df.sort_values(by='Order').reset_index(drop=True)

        def get_color_sequence(scheme, n_colors):
            """Return ``n_colors`` colours sampled evenly from a named Plotly scale.

            'rainbow'/'hsv', unknown names and any error fall back to evenly
            spaced HSL hues.
            """
            fallback = [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
            try:
                if scheme.lower() in ['rainbow', 'hsv']:
                    return fallback

                for palette_module in (px.colors.sequential, px.colors.diverging, px.colors.cyclical):
                    color_sequence = getattr(palette_module, scheme, None)
                    # Only accept real colour lists (the modules also expose helper functions).
                    if isinstance(color_sequence, (list, tuple)) and color_sequence:
                        indices = np.linspace(0, len(color_sequence) - 1, n_colors)
                        return [color_sequence[int(i)] for i in indices]

                return fallback

            except Exception as e:
                print(f"Error generating colors: {e}")
                return fallback

        colors = get_color_sequence(self.color_scheme.value, len(pro_list))
        # First position wins for duplicated descriptions (same as list.index).
        color_index = {}
        for i, desc in enumerate(pro_list):
            color_index.setdefault(desc, i)

        fig = go.Figure()

        for _, row in scaled_df.iterrows():
            protein_description = row['Description']
            if protein_description not in color_index:
                continue

            hover_text = []
            for sample in sample_orders:
                rel_value_hov = row[hover_rel_cols[sample]]
                abs_value = row[sample_mapping[sample]]
                hover_text.append(
                    f"Protein: {row['Master Protein Accessions']}<br>" +
                    f"Description: {row['Description']}<br>" +
                    f"Sample: {sample.replace('Rel_Avg_', '')}<br>" +
                    f"Relative Absorbance: {rel_value_hov:.2f}%<br>" +
                    f"Absolute Absorbance: {abs_value:.2e}<br>"
                )

            fig.add_trace(go.Bar(
                name=protein_description,
                x=x_labels,
                y=row[sample_orders],
                marker_color=colors[color_index[protein_description]],
                hovertext=hover_text,
                hoverinfo='text'
            ))

        # Custom labels from the widgets; the title parameter is the fallback.
        x_label = self.xlabel_widget.value or ''
        y_label = self.ylabel_widget.value or 'Scaled Absolute Absorbance'
        plot_title = self.title_widget.value or title or ''
        legend_title = self.legend_widget.value or 'Protein Origins:'

        fig.update_layout(
            barmode='stack',
            title={
                'text': plot_title,
                'y': 0.95,
                'x': 0.5,
                'xanchor': 'center',
                'yanchor': 'top',
                'font': {"size": 18, 'color': 'black'}
            },
            xaxis_title=x_label,
            yaxis_title=y_label,
            legend_title=legend_title,
            legend={
                'yanchor': "top",
                'y': 0.95,
                'xanchor': "left",
                'x': 1.05,
                'traceorder': 'normal',
                'font': {"size": 14, 'color': 'black'},
                'bgcolor': 'rgba(255, 255, 255, 0.9)'
            },
            showlegend=True,
            template='plotly_white',
            height=800,
            width=1000,
            margin=dict(
                t=100,
                l=100,
                r=500,
                b=100
            ),
            hoverlabel=dict(
                bgcolor="white",
                font_size=14,
                font_family="Arial"
            ),
            xaxis=dict(
                showline=True,
                linewidth=1,
                linecolor='black',
                mirror=False
            ),
            yaxis=dict(
                showline=True,
                linewidth=1,
                linecolor='black',
                mirror=False
            )
        )

        fig.update_xaxes(
            tickangle=45,
            title_font={"size": 16},
            tickfont={"size": 14},
            tickfont_color="black",  # Black tick labels
            title_font_color="black",  # Black axis title
        )

        fig.update_yaxes(
            title_font={"size": 16},
            tickfont={"size": 14},
            tickfont_color="black",  # Black tick labels
            title_font_color="black",  # Black axis title
            gridcolor="lightgray",  # Light gray grid lines
            showgrid=True,  # Show grid lines
            zeroline=False,  # Hide zero line
            exponentformat='E',
            showexponent='all',
            tickformat=".1e"
        )

        # Text-only trace printing each group's total absorbance (over all
        # proteins in sum_df, not only the plotted ones) at that height; its
        # legend entry toggles the labels.
        fig.add_trace(go.Scatter(
            x=x_labels,
            y=[total_sums[label] for label in x_labels],
            mode='text',
            text=[f"{total_sums[label]:.2e}" for label in x_labels],
            textposition='top center',
            textfont=dict(size=12, color='black'),
            showlegend=True,
            name='Show Summed Absorbance',
            hoverinfo='none',
            texttemplate='%{text}'
        ))

        # Mark generation as complete
        self.state_manager.generate_completed()
        self.export_button.disabled = False
        self.download_plot_button.disabled = False

        return fig

    def _on_plot_button_click(self, b):
        if self.current_fig is not None:
            self.state_manager.generate_completed()
            self.export_button.disabled = False
            self.download_plot_button.disabled = False
        with self.plot_output:
            self.plot_output.clear_output(wait=True)
            
            if not self.group_select.value:
                print("Please select at least one group to plot.")
                return
            
            selected_groups = list(self.group_select.value)
            if self.process_data(selected_groups):
                # Calculate average Absorbance for sorting using only selected groups
                selected_avg_columns = [f'Avg_{var}' for var in selected_groups]
                #self.proteins_df['avg_absorbance_all'] = self.proteins_df[selected_avg_columns].mean(axis=1)
                                
                # Calculate sum of all selected columns
                total_sum =  self.proteins_df[selected_avg_columns].sum().sum()
                
                # Calculate row sums
                row_sums =  self.proteins_df[selected_avg_columns].sum(axis=1)
                
                # Calculate relative percentage contribution
                self.proteins_df['avg_absorbance_all'] = (row_sums / total_sum * 100).round(2)
                
                # Sort and get top N proteins
                self.proteins_df = self.proteins_df.sort_values('avg_absorbance_all', ascending=False)
                
                # Sort and get top N proteins
                pro_list = list(self.proteins_df ['Description'].head(self.num_proteins_widget.value))
                
                # Create and store plot
                self.current_fig = self.plot_stacked_bar_scaled(
                    pro_list=pro_list,
                    title='',
                    selected_groups=selected_groups
                )
                if self.current_fig is not None:
                    self.current_fig.show()
            else:
                print("Please upload all required files first.")
                print("Error creating plot. Please check your data.")
    
    def generate_download_link(self, content, filename, filetype='text/csv'):
        """
        Build HTML that embeds ``content`` as a base64 data URI and auto-clicks it.

        When displayed as HTML in a notebook output, the hidden ``<a>`` element
        is clicked by the inline script, so the browser starts the download
        (only where the front end executes scripts in HTML output).

        Parameters
        ----------
        content : pandas.DataFrame, str or bytes
            Payload. A DataFrame is serialised to CSV, without the index for
            ``'text/csv'`` and with the index for any other ``filetype``; a str
            is encoded with ``str.encode()`` (UTF-8).
        filename : str
            Suggested name for the downloaded file (HTML-escaped in the markup).
        filetype : str, default 'text/csv'
            MIME type placed in the data URI.

        Returns
        -------
        str
            HTML snippet with the hidden link and its auto-click script.
        """
        from html import escape
        import uuid

        if isinstance(content, pd.DataFrame):
            if filetype == 'text/csv':
                content = content.to_csv(index=False)
            else:
                content = content.to_csv(index=True)
        if isinstance(content, str):
            content = content.encode()
        b64 = base64.b64encode(content).decode()

        # Unique id per link: with a fixed id, getElementById returns the first
        # matching element in the page, which may be an older, stale link.
        link_id = f"download_link_{uuid.uuid4().hex}"
        safe_name = escape(filename, quote=True)
        safe_type = escape(filetype, quote=True)
        return f"""
            <a id="{link_id}" href="data:{safe_type};base64,{b64}" 
               download="{safe_name}"
               style="display: none;">
                Download {safe_name}
            </a>
            <script>
                document.getElementById('{link_id}').click();
            </script>
            """
    
    def _on_download_plot_click(self, b):
        """Handle plot download button click with automatic download"""
        if self.current_fig is not None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            plot_filename = f'protein_plot_{timestamp}.html'
            
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(self.generate_download_link(
                    self.current_fig.to_html(),
                    plot_filename,
                    'text/html'
                )))
        else:
            print("Please generate a plot first.")
    
    def _on_export_button_click(self, b):
        """Handle data export with automatic download"""
        if self.proteins_df is not None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            data_filename = f'protein_absorbance_analysis_{timestamp}.csv'
            
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(self.generate_download_link(
                    self.proteins_df,
                    data_filename,
                    'text/csv'
                )))
        else:
            print("Please generate the analysis first.")

    def display(self):
        """Render this plotter's widget container (``self.widget_box``) in the notebook.

        Inside the method body, ``display`` resolves to the module-level
        IPython function, not to this method.
        """
        container = self.widget_box
        display(container)



In [74]:

# Initialize the interface: DataTransformation holds the upload widget and the
# parsed data (e.g. merged_df, proteins_dic); setup_data_loading_ui builds and
# shows the upload controls.
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

# Create protein plotter on top of the same transformer instance, so it reads
# the uploaded data (such as proteins_dic) through it, then show its controls.
# Order matters: the transformer must exist before the plotter is built.
protein_plotter = ProteinPlotter(data_transformer)
protein_plotter.display()

VBox(children=(HTML(value='<h4>Upload Data File:</h4>'), HBox(children=(FileUpload(value=(), accept='.csv,.txt…

Output(layout=Layout(margin='0 0 0 20px', width='300px'))

VBox(children=(HTML(value='<h4>Plot Controls:</h4>'), IntSlider(value=10, description='Top Number of Proteins:…