In [1]:
import pandas as pd
import numpy as np
from datetime import datetime
import json, io, base64, re, os, requests, time, traceback
from IPython.display import display, HTML, clear_output

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

import ipywidgets as widgets
from xml.etree import ElementTree
# Initialize settings
import _settings as settings

# Global variables from settings
# Module-level aliases for two values defined in the local `_settings` module.
# NOTE(review): the structure of SPEC_TRANSLATE_LIST and plotly_colors is not
# visible in this notebook — confirm it in `_settings` before relying on it.
spec_translate_list = settings.SPEC_TRANSLATE_LIST
plotly_colors = settings.plotly_colors

In [2]:
class DataTransformation:
    def __init__(self):
        self.merged_df = None
        self.proteins_dic = {}
        self.output_area = None
        self.merged_uploader = None

    def create_download_link(self, file_path, label):
        """Create a download link for a file."""
        if os.path.exists(file_path):
            # Read file content and encode it as base64
            with open(file_path, 'rb') as f:
                content = f.read()
            b64_content = base64.b64encode(content).decode('utf-8')

            # Generate the download link HTML
            return widgets.HTML(f"""
                <a download="{os.path.basename(file_path)}" 
                   href="data:application/octet-stream;base64,{b64_content}" 
                   style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                    {label}
                </a>
            """)
        else:
            # Show an error message if the file does not exist
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{file_path}" not found!
                </span>
            """)

    def setup_data_loading_ui(self):
        """Build the upload controls, render them, and subscribe to upload events."""
        # Single-file uploader limited to the supported table formats
        self.merged_uploader = widgets.FileUpload(
            description='Upload Merged Data File',
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Messages produced while loading a file are rendered here
        self.output_area = widgets.Output()

        # Uploader shown next to a download link for the example file
        example_link = self.create_download_link("example_merged_dataframe.csv", "Example")
        uploader_row = widgets.HBox(
            [self.merged_uploader, example_link],
            layout=widgets.Layout(align_items='center')
        )

        # Left-hand column: heading, uploader row, then the load messages
        column_layout = widgets.Layout(width='300px', margin='0 20px 0 0')
        upload_column = widgets.VBox(
            [
                widgets.HTML("<h4>Upload Data File:</h4>"),
                uploader_row,
                self.output_area,
            ],
            layout=column_layout
        )

        # Separate output used for status / progress display
        status_layout = widgets.Layout(width='300px', margin='0 0 0 20px')
        self.status_area = widgets.Output(layout=status_layout)

        display(upload_column, self.status_area)

        # React whenever the uploader's value changes
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')

    def _validate_and_clean_data(self, df):
        """
        Validate and clean the uploaded data, preserving numeric data even if stored as strings.
        Returns tuple of (cleaned_df, warnings, errors)
        """
        warnings = []
        errors = []
        cleaned_df = df.copy()
    
        # Check required columns exist
        required_columns = [
            'Master Protein Accessions', 
            'unique ID'
        ]
        
        # Check that at least one Avg_ column exists
        avg_columns = [col for col in df.columns if col.startswith('Avg_')]
        if not avg_columns:
            errors.append("No columns starting with 'Avg_' found in the data")
            return None, warnings, errors
            
        # Add Avg_ columns to required columns
        required_columns.extend(avg_columns)
        
        missing = set(required_columns) - set(df.columns)
        if missing:
            errors.append(f"Missing required columns: {', '.join(missing)}")
            return None, warnings, errors
    
        # Separate numeric and non-numeric columns
        numeric_columns = avg_columns  # Avg_ columns should be numeric
        text_columns = ['Master Protein Accessions', 'unique ID']
    
        # Handle blank values differently for numeric vs text columns
        for column in required_columns:
            if column in numeric_columns:
                # For numeric columns, try to convert to numeric first
                try:
                    # Convert to numeric, coerce errors to NaN
                    cleaned_df[column] = pd.to_numeric(cleaned_df[column], errors='coerce')
                    blank_count = cleaned_df[column].isna().sum()
                    if blank_count > 0:
                        warnings.append(f"Found {blank_count} invalid/blank numeric values in {column} column")
                except Exception as e:
                    errors.append(f"Error converting {column} to numeric: {str(e)}")
                    return None, warnings, errors
            elif column in text_columns:
                # For text columns, check for truly empty values
                blank_mask = cleaned_df[column].isna() | (cleaned_df[column].astype(str).str.strip() == '')
                blank_count = blank_mask.sum()
                if blank_count > 0:
                    warnings.append(f"Dropping {blank_count} rows with blank values in {column} column")
                    cleaned_df = cleaned_df[~blank_mask]
    
        # Check for invalid characters in non-blank rows
        if len(cleaned_df) > 0:
            # Check Positions in Proteins
            invalid_pos = cleaned_df['Positions in Proteins'].apply(
                lambda x: ',' in str(x) or ':' in str(x)
            )
            if invalid_pos.any():
                errors.append(
                    "Found invalid characters (',' or ':') in Positions in Proteins column. "
                    "Please update the file and upload again."
                )
            
            # Check Master Protein Accessions
            invalid_acc = cleaned_df['Master Protein Accessions'].apply(
                lambda x: ',' in str(x) or ':' in str(x)
            )
            if invalid_acc.any():
                errors.append(
                    "Found invalid characters (',' or ':') in Master Protein Accessions column. "
                    "Please update the file and upload again."
                )
    
        return cleaned_df, warnings, errors

    def _process_protein_info(self, df):
        """
        Process protein information from the dataframe and store in proteins_dic.
        Asks user whether to fetch from UniProt or use accession IDs when protein info is missing.
        """
        # Initialize a cache for UniProt information to avoid redundant queries
        self.uniprot_cache = getattr(self, 'uniprot_cache', {})
        
        # Check if we need to fetch any data from UniProt
        has_protein_info = all(col in df.columns for col in ['protein_name', 'protein_species'])
        if has_protein_info:
            # Check if we have valid data for all entries
            all_data_present = (
                df['protein_name'].notna().all() and 
                df['protein_species'].notna().all() and
                (df['protein_name'] != '').all() and
                (df['protein_species'] != '').all()
            )
            if all_data_present:
                # If we have all data, just process it silently
                protein_info = df.groupby('Master Protein Accessions').agg({
                    'protein_name': 'first',
                    'protein_species': 'first'
                }).reset_index()
                
                for _, row in protein_info.iterrows():
                    protein_id = row['Master Protein Accessions']
                    self.proteins_dic[protein_id] = {
                        "name": row['protein_name'],
                        "species": row['protein_species']
                    }
                return len(self.proteins_dic)

        # Store the dataframe for later processing
        self._protein_df_to_process = df
        
        # Create a flag to track if processing is complete
        self._protein_processing_complete = False
        
        # If we need to fetch data, ask the user what they want to do
        # Display in the status area
        with self.status_area:
            self.status_area.clear_output()
            
            # Create buttons for user choice
            fetch_button = widgets.Button(
                description='Query UniProt',
                button_style='info',
                tooltip='Fetch protein names from UniProt database (may take time)',
                layout=widgets.Layout(width='250px')
            )
            
            use_accession_button = widgets.Button(
                description='Use Protein IDs',
                button_style='warning',
                tooltip='Use protein accession IDs as names without querying UniProt',
                layout=widgets.Layout(width='250px')
            )
            
            # Define button click handlers
            fetch_button.on_click(lambda b: self._process_proteins_with_choice(True))
            use_accession_button.on_click(lambda b: self._process_proteins_with_choice(False))
            
            display(HTML("""
                <div style="padding: 15px; margin: 10px 0; border-left: 4px solid #17a2b8; background-color: #f8f9fa;">
                    <h4 style="margin-top: 0;">Protein Information Missing</h4>
                    <p>Some protein names or species information is missing in your data.</p>
                    <p>Would you like to:</p>
                         <ul>
                                <li>Fetch protein names from UniProt database (may take time)</li>
                                <li>Use protein accession IDs as names without querying UniProt</li>
                        </ul>
                </div>
            """))
            display(widgets.HBox([fetch_button, use_accession_button]))
        
        # Return the current count, but processing will continue when a button is clicked
        return len(self.proteins_dic)

    def _process_proteins_with_choice(self, fetch_from_uniprot):
        """
        Process proteins based on user choice.
        This is called when the user clicks one of the choice buttons.
        """
        # Get the dataframe to process
        df = self._protein_df_to_process
        
        # Clear the status area and show processing message
        with self.status_area:
            self.status_area.clear_output()
            if fetch_from_uniprot:
                display(HTML('<div style="color: #17a2b8; padding: 10px; margin: 10px 0;">Fetching protein information from UniProt...</div>'))
            else:
                display(HTML('<div style="color: #ffc107; padding: 10px; margin: 10px 0;">Using protein accession IDs as names...</div>'))
        
        # Process proteins based on user choice
        # Use the status area for progress display
        with self.status_area:
            # Initialize counters
            total_proteins = 0
            uniprot_found = 0
            uniprot_not_found = 0
            multiple_entries = 0
            cached_proteins = 0
            
            # Check if we need to fetch any data from UniProt
            has_protein_info = all(col in df.columns for col in ['protein_name', 'protein_species'])
            
            # Group by protein accession to get unique proteins
            protein_info = df.groupby('Master Protein Accessions').agg({
                'protein_name': 'first' if 'protein_name' in df.columns else lambda x: None,
                'protein_species': 'first' if 'protein_species' in df.columns else lambda x: None
            }).reset_index()

            progress_html = """
                <style>
                    .fetch-status { font-family: monospace; margin: 10px 0; padding: 10px; }
                    .fetch-progress { margin: 5px 0; padding: 5px; }
                    .success { color: #28a745; }
                    .warning { color: #ffc107; }
                    .error { color: #dc3545; }
                    .info { color: #17a2b8; }
                    .summary { margin-top: 10px; padding: 10px;}
                </style>
                <div class="fetch-status">
                    <div id="progress-updates"></div>
                </div>
            """
            
            # First, collect all proteins that need fetching
            proteins_to_fetch = []
            
            for _, row in protein_info.iterrows():
                total_proteins += 1
                protein_id = row['Master Protein Accessions']
                
                # Skip entries with multiple protein IDs
                if ';' in protein_id:
                    multiple_entries += 1
                    self.proteins_dic[protein_id] = {
                        "name": protein_id,
                        "species": "Multiple"
                    }
                    continue
                
                # Use existing data if available and not empty
                if (has_protein_info and 
                    pd.notna(row['protein_name']) and 
                    pd.notna(row['protein_species']) and 
                    row['protein_name'] != '' and 
                    row['protein_species'] != ''):
                    self.proteins_dic[protein_id] = {
                        "name": row['protein_name'],
                        "species": row['protein_species']
                    }
                    continue
                
                # Check if we already have this protein in cache
                if protein_id in self.uniprot_cache:
                    cached_proteins += 1
                    name, species = self.uniprot_cache[protein_id]
                    self.proteins_dic[protein_id] = {
                        "name": name,
                        "species": species
                    }
                    continue
                
                # If we need to fetch and user chose to fetch from UniProt
                if fetch_from_uniprot:
                    proteins_to_fetch.append(protein_id)
                else:
                    # Use accession ID as name
                    self.proteins_dic[protein_id] = {
                        "name": protein_id,
                        "species": "Unknown"
                    }
            
            # Process proteins in batches if fetching from UniProt
            if fetch_from_uniprot and proteins_to_fetch:
                # Update progress display
                display(HTML(progress_html + f"""
                    <div class="fetch-progress info">
                        Preparing to fetch {len(proteins_to_fetch)} proteins from UniProt in batches...
                    </div>
                """))
                
                # Process in batches of 50 (adjust as needed)
                batch_size = 50
                total_batches = (len(proteins_to_fetch) + batch_size - 1) // batch_size
                
                for batch_num in range(total_batches):
                    start_idx = batch_num * batch_size
                    end_idx = min((batch_num + 1) * batch_size, len(proteins_to_fetch))
                    current_batch = proteins_to_fetch[start_idx:end_idx]
                    
                    # Update progress
                    self.status_area.clear_output(wait=True)
                    display(HTML(progress_html + f"""
                        <div class="fetch-progress info">
                            Fetching batch {batch_num + 1}/{total_batches} ({len(current_batch)} proteins)...
                        </div>
                        <div class="summary">
                            <h4>Progress:</h4>
                            <ul>
                                <li>Total proteins: {total_proteins}</li>
                                <li>Proteins from cache: {cached_proteins}</li>
                                <li>UniProt matches found: {uniprot_found}</li>
                                <li>UniProt matches not found: {uniprot_not_found}</li>
                                <li>Multiple entry proteins: {multiple_entries}</li>
                                <li>Remaining to fetch: {len(proteins_to_fetch) - start_idx}</li>
                            </ul>
                        </div>
                    """))
                    
                    # Fetch the batch using individual queries for now
                    # We'll implement batch fetching later
                    batch_results = {}
                    for protein_id in current_batch:
                        try:
                            name, species = self.fetch_uniprot_info(protein_id)
                            if name and species:
                                batch_results[protein_id] = (name, species)
                        except Exception as e:
                            print(f"Error fetching {protein_id}: {str(e)}")
                    
                    # Process the results
                    for protein_id in current_batch:
                        if protein_id in batch_results:
                            name, species = batch_results[protein_id]
                            uniprot_found += 1
                            self.proteins_dic[protein_id] = {
                                "name": name,
                                "species": species
                            }
                            # Add to cache
                            self.uniprot_cache[protein_id] = (name, species)
                        else:
                            uniprot_not_found += 1
                            self.proteins_dic[protein_id] = {
                                "name": protein_id,
                                "species": "Unknown"
                            }
                            # Cache the negative result too
                            self.uniprot_cache[protein_id] = (protein_id, "Unknown")
            
            # Show final summary
            self.status_area.clear_output(wait=True)
            display(HTML(f"""
                <div class="fetch-status">
                    <h4 style="color:green;"><b>Protein Processing Complete!</b></h4>
                    <div class="summary">
                        <h4>Final Summary:</h4>
                        <ul>
                            <li>Total proteins processed: {total_proteins}</li>
                            <li>Proteins from cache: {cached_proteins}</li>
                            <li>Multiple entry proteins: {multiple_entries}</li>
                            {"<li>UniProt matches found: " + str(uniprot_found) + "</li>" if fetch_from_uniprot else ""}
                            {"<li>UniProt matches not found: " + str(uniprot_not_found) + "</li>" if fetch_from_uniprot else ""}
                        </ul>
                    </div>
                </div>
            """))
        
        # Mark processing as complete
        self._protein_processing_complete = True
        
        # Update the merged_df with the protein information
        if hasattr(self, 'merged_df') and self.merged_df is not None:
            # Add protein_name and protein_species columns if they don't exist
            if 'protein_name' not in self.merged_df.columns:
                self.merged_df['protein_name'] = ''
            if 'protein_species' not in self.merged_df.columns:
                self.merged_df['protein_species'] = ''
            
            # Update the columns with the fetched information
            for protein_id, info in self.proteins_dic.items():
                mask = self.merged_df['Master Protein Accessions'] == protein_id
                self.merged_df.loc[mask, 'protein_name'] = info['name']
                self.merged_df.loc[mask, 'protein_species'] = info['species']
        
        # Return the number of proteins processed
        return len(self.proteins_dic)

    def fetch_uniprot_info_batch(self, protein_ids, max_retries=3, timeout=30):
        """
        Fetch protein information for multiple proteins at once using UniProt's batch API.
        Returns a dictionary mapping protein IDs to (name, species) tuples.
        """
        if not protein_ids:
            return {}
        
        results = {}
        
        try:
            # Use the batch REST API endpoint
            batch_url = 'https://rest.uniprot.org/uniprotkb/search'
            
            # Create a query with OR conditions for each protein ID
            query = ' OR '.join([f'accession:{pid}' for pid in protein_ids])
            
            # Parameters for the request
            params = {
                'query': query,
                'format': 'json',
                'fields': 'accession,protein_name,organism_name',
                'size': len(protein_ids)  # Request all results in one response
            }
            
            # Make the request with timeout and retry logic
            retry_count = 0
            while retry_count < max_retries:
                try:
                    response = requests.get(batch_url, params=params, timeout=timeout)
                    
                    # Handle rate limiting
                    if response.status_code == 429:  # Too Many Requests
                        retry_after = int(response.headers.get('Retry-After', 5))
                        print(f"Rate limited. Waiting {retry_after} seconds...")
                        time.sleep(retry_after + 1)  # Add 1 second buffer
                        retry_count += 1
                        continue
                    
                    # Break if successful
                    if response.status_code == 200:
                        break
                    
                    # Handle other errors
                    response.raise_for_status()
                    
                except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
                    retry_count += 1
                    wait_time = 2 ** retry_count  # Exponential backoff
                    print(f"Request failed: {str(e)}. Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
            
            # Process the response
            if response.status_code == 200:
                data = response.json()
                
                for entry in data.get('results', []):
                    accession = entry.get('primaryAccession')
                    
                    # Extract protein name
                    protein_name = None
                    protein_data = entry.get('proteinDescription', {})
                    
                    # Try to find a common/short name first
                    if 'recommendedName' in protein_data:
                        if 'shortNames' in protein_data['recommendedName']:
                            protein_name = protein_data['recommendedName']['shortNames'][0]['value']
                        else:
                            protein_name = protein_data['recommendedName'].get('fullName', {}).get('value')
                    
                    # Try alternative names if no recommended name found
                    if not protein_name and 'alternativeNames' in protein_data and protein_data['alternativeNames']:
                        if 'shortNames' in protein_data['alternativeNames'][0]:
                            protein_name = protein_data['alternativeNames'][0]['shortNames'][0]['value']
                        else:
                            protein_name = protein_data['alternativeNames'][0].get('fullName', {}).get('value')
                    
                    # Extract species name
                    species = None
                    organism_data = entry.get('organism', {})
                    organism_names = organism_data.get('names', [])
                    
                    # Try to find common name first
                    for name in organism_names:
                        if name['type'] == 'common':
                            species = name['value']
                            break
                    
                    # Fallback to scientific name
                    if not species and organism_names:
                        for name in organism_names:
                            if name['type'] == 'scientific':
                                species = name['value']
                                break
                    
                    # Clean up protein name if found
                    if protein_name:
                        protein_name = protein_name.split(' precursor')[0].split(' (')[0]
                    
                    # Store the results
                    if accession and (protein_name or species):
                        results[accession] = (protein_name or accession, species or "Unknown")
            
            return results
        
        except Exception as e:
            print(f"Error in batch fetch: {str(e)}")
            return {}
        
    def _load_merged_data(self, file_data):
        """
        Load and validate merged data file
        Returns tuple of (dataframe, status)
        """
        try:
            content = bytes(file_data.content)
            filename = file_data.name
            extension = filename.split('.')[-1].lower()

            file_stream = io.BytesIO(content)

            # Load data based on file extension
            try:
                if extension == 'csv':
                    df = pd.read_csv(file_stream)
                elif extension in ['txt', 'tsv']:
                    df = pd.read_csv(file_stream, delimiter='\t')
                elif extension == 'xlsx':
                    df = pd.read_excel(file_stream)
                else:
                    display(HTML(f'<b style="color:red;">Error: Unsupported file format</b>'))
                    return None, 'no'
            except Exception as e:
                display(HTML(f'<b style="color:red;">Error reading file: {str(e)}</b>'))
                return None, 'no'

            # Check for protein info columns and notify user
            missing_columns = []
            if 'protein_name' not in df.columns:
                missing_columns.append('protein_name')
                df['protein_name'] = ''
            if 'protein_species' not in df.columns:
                missing_columns.append('protein_species')
                df['protein_species'] = ''
                
            if missing_columns:
                notification = f"""
                <div style="padding: 10px; margin: 10px 0;">
                    <p style="color: #17a2b8; margin: 0;">
                        <b>Notice:</b> The following columns are missing from your data:
                        <ul style="color: #17a2b8; margin: 5px 0;">
                            {''.join(f'<li>{col}</li>' for col in missing_columns)}
                        </ul>
                        </p>
                        <p style="color: #17a2b8; margin: 0;">
                        UniProt will be searched to automatically fill in this information. <br>
                        Alternativly you can upload a standardized file from the data transomation module with the protein information. 
                    </p>
                </div>
                """
                display(HTML(notification))

            # Validate and clean data
            cleaned_df, warnings, errors = self._validate_and_clean_data(df)

            # Warnings about invalid/blank values are commented out
            # if warnings:
            #     warning_html = "<br>".join([
            #         f'<b style="color:orange;">Warning: {w}</b>'
            #         for w in warnings
            #     ])
            #     display(HTML(warning_html))

            # Display errors if any
            if errors:
                error_html = "<br>".join([
                    f'<b style="color:red;">Error: {e}</b>'
                    for e in errors
                ])
                display(HTML(error_html))
                return None, 'no'

            if cleaned_df is not None and len(cleaned_df) > 0:
                # Process protein information
                num_proteins = self._process_protein_info(cleaned_df)
                
                # Add information about remaining rows and processed proteins
                success_message = f"""
                <div style="padding: 10px; margin: 10px 0; border-left: 4px solid #28a745; background-color: #f8f9fa;">
                    <p style="color: #28a745; margin: 0;">
                        <b>Data Import Complete!</b><br>
                        • Data imported successfully with {cleaned_df.shape[0]} rows and {cleaned_df.shape[1]} columns.<br>
                        • Processed data contains {len(cleaned_df)} rows after removing blank values.<br>
                        • Successfully processed information for {num_proteins} unique proteins.
                    </p>
                </div>
                """
                #display(HTML(success_message))
                return cleaned_df, 'yes'
            else:
                display(HTML('<b style="color:red;">Error: No valid data rows remaining after cleaning</b>'))
                return None, 'no'

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error processing file: {str(e)}</b>'))
            return None, 'no'

    def _on_merged_upload_change(self, change):
        """Handle merged data file upload"""
        if change['type'] == 'change' and change['name'] == 'value':
            with self.output_area:
                self.output_area.clear_output()
                if change['new'] and len(change['new']) > 0:
                    file_data = change['new'][0]
                    df, status = self._load_merged_data(file_data)
                    if status == 'yes' and df is not None:
                        self.merged_df = df  # Only set merged_df if validation passed
                        display(HTML(
                            f'<b style="color:green;">Data imported successfully with '
                            f'{df.shape[0]} rows and {df.shape[1]} columns.</b>'
                        ))

    def fetch_uniprot_info(self, protein_id):
        """
        Fetch protein information from UniProt, prioritizing common names.
        Returns tuple of (protein_common_name, species_common_name) or (None, None) if not found.
        """
        try:
            # Try REST API first
            rest_url = f'https://rest.uniprot.org/uniprotkb/{protein_id}.json'
            response = requests.get(rest_url)
            
            if response.status_code == 200:
                data = response.json()
                
                # Get protein common name
                try:
                    # Look for protein names
                    names = data['proteinDescription']
                    protein_name = None
                    
                    # Try to find a common/short name
                    if 'shortNames' in names.get('recommendedName', {}):
                        protein_name = names['recommendedName']['shortNames'][0]['value']
                    elif 'shortNames' in names.get('alternativeNames', [{}])[0]:
                        protein_name = names['alternativeNames'][0]['shortNames'][0]['value']
                    else:
                        # Fallback to full name
                        protein_name = names.get('recommendedName', {}).get('fullName', {}).get('value')
                    
                    # Get species common name
                    organism_data = data['organism']
                    species = None
                    for name in organism_data.get('names', []):
                        if name['type'] == 'common':
                            species = name['value']
                            break
                    if not species:  # Fallback to scientific name
                        species = organism_data.get('scientificName')
                    
                    return protein_name, species
                    
                except KeyError:
                    pass  # Fall through to XML approach
            
            # Fall back to XML API
            xml_url = f'https://www.uniprot.org/uniprot/{protein_id}.xml'
            response = requests.get(xml_url)
            
            if response.status_code != 200:
                return None, None
                
            root = ElementTree.fromstring(response.content)
            ns = {'up': 'http://uniprot.org/uniprot'}
            
            # Get protein common name with fallbacks
            protein_name = None
            # Try short name first
            name_element = (
                root.find('.//up:recommendedName/up:shortName', ns) or
                root.find('.//up:alternativeName/up:shortName', ns) or
                root.find('.//up:recommendedName/up:fullName', ns) or
                root.find('.//up:submittedName/up:fullName', ns)
            )
            protein_name = name_element.text if name_element is not None else None
            
            # Get species common name
            species = None
            # Try common name first
            organism = root.find('.//up:organism/up:name[@type="common"]', ns)
            if organism is not None:
                species = organism.text
            else:
                # Fallback to scientific name
                organism = root.find('.//up:organism/up:name[@type="scientific"]', ns)
                species = organism.text if organism is not None else "Unknown"
            
            if protein_name:
                # Clean up protein name - remove any "precursor" or similar suffixes
                protein_name = protein_name.split(' precursor')[0].split(' (')[0]
                return protein_name, species
            else:
                return None, None
            
        except Exception as e:
            print(f"Error fetching UniProt data for {protein_id}: {str(e)}")
            return None, None




In [3]:
class PlotState:
    """Holds the current plot inputs plus the lock flag for the action buttons."""

    def __init__(self):
        # Every tracked input starts unset; buttons stay locked until a plot
        # has been generated.
        initial = dict.fromkeys((
            'uploadedData',
            'topProteins',
            'groupSelection',
            'xLabel',
            'yLabel',
            'colorScheme',
            'lastGenerated',
        ))
        initial['buttonsLocked'] = True
        self.current_state = initial

    def update_state(self, **kwargs):
        """Merge ``kwargs`` into the state and lock the buttons again."""
        # The trailing keyword is applied last, so the lock always ends up True
        # even if the caller passed buttonsLocked explicitly.
        self.current_state.update(kwargs, buttonsLocked=True)

    def generate_completed(self):
        """Unlock the buttons and remember the inputs the plot was built from."""
        state = self.current_state
        state['buttonsLocked'] = False
        # Snapshot everything except the bookkeeping keys themselves.
        snapshot = dict(state)
        snapshot.pop('lastGenerated', None)
        snapshot.pop('buttonsLocked', None)
        state['lastGenerated'] = snapshot

    def get_state(self):
        """Return the live state dictionary (not a copy)."""
        return self.current_state


In [4]:
 
class ProteinPlotter:
    def __init__(self, data_transformer):
        """Build the plot-control widgets and wire up their callbacks.

        Args:
            data_transformer: Object exposing ``merged_df``, ``proteins_dic`` and
                ``merged_uploader``. ``merged_uploader`` is observed here, so it
                must already be a widget (i.e. the data-loading UI has been set
                up) when the plotter is created.
        """
        self.data_transformer = data_transformer
        self.plot_output = widgets.Output()
        self.info_output = widgets.Output()
        self.export_output = widgets.Output()
        self.proteins_df = None
        self.sum_df = None

        # State manager that tracks whether export/download are allowed
        self.state_manager = PlotState()

        def create_help_icon(tooltip_text):
            """Create a help icon widget whose tooltip shows ``tooltip_text``."""
            icon_markup = '<i class="fa fa-question-circle" style="color: #007bff;"></i>'
            return widgets.HTML(
                f'<div title="{tooltip_text}" style="display: inline-block;">{icon_markup}</div>'
            )

        # --- Action buttons ---------------------------------------------------
        self.download_plot_button = widgets.Button(
            description='Download Interactive Plot',
            button_style='info',
            icon='file',
            layout=widgets.Layout(width='200px')
        )

        # Multi-select widget for groups (filled once data is uploaded)
        self.group_select = widgets.SelectMultiple(
            options=[],
            description='Groups:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px', height='100px')
        )

        self.plot_button = widgets.Button(
            description='Generate/Update Data',
            button_style='success',
            icon='refresh',
            layout=widgets.Layout(width='200px')
        )

        self.export_button = widgets.Button(
            description='Export Data',
            button_style='info',
            icon='download',
            layout=widgets.Layout(width='200px')
        )

        # --- Label customization ----------------------------------------------
        self.xlabel_widget = widgets.Text(
            description='X Label:',
            placeholder='Enter x-axis label',
            layout=widgets.Layout(width='300px')
        )

        self.ylabel_widget = widgets.Text(
            description='Y Label:',
            placeholder='Enter y-axis label',
            layout=widgets.Layout(width='300px')
        )
        self.legend_widget = widgets.Text(
            description='Legend Title',
            placeholder='Enter a custom legend title',
            layout=widgets.Layout(width='300px')
        )
        self.title_widget = widgets.Text(
            description='Plot Title',
            placeholder='Enter a custom plot title',
            layout=widgets.Layout(width='300px')
        )

        # Color scheme dropdown; entries starting with '---' are category
        # headers, which _get_color_sequence maps back to HSV if selected.
        color_schemes = [
            '--- DEFAULT (HSV)---',
            'HSV',  # Default option
            '--- QUALITATIVE (BEST FOR CATEGORIES) ---',
            'Plotly', 'D3', 'G10', 'T10', 'Alphabet',
            'Set1', 'Set2', 'Set3', 'Pastel1', 'Pastel2', 'Paired',
            '--- SEQUENTIAL ---',
            'Viridis', 'Cividis', 'Inferno', 'Magma', 'Plasma',
            'Hot', 'Jet', 'Blues', 'Greens', 'Reds', 'Purples', 'Oranges',
            '--- DIVERGING ---',
            'Spectral', 'RdBu', 'RdYlBu', 'RdYlGn', 'PiYG', 'PRGn', 'BrBG', 'RdGy',
            '--- CYCLICAL ---',
            'IceFire', 'Edge', 'Twilight'
        ]

        self.color_scheme = widgets.Dropdown(
            options=color_schemes,
            value='HSV',  # Default value
            description='Color Scheme:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Plot orientation toggle
        self.invert_plot = widgets.RadioButtons(
            description='Plot Orientation:',
            options=['By Sample', 'By Protein'],
            value='By Sample',  # Default selection
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px', height='auto'),
            disabled=False,
            indent=True  # Keeps options aligned with description instead of appearing below
        )

        # Protein selection list; SelectMultiple values are tuples
        self.protein_selector = widgets.SelectMultiple(
            options=['All'],
            value=('All',),
            description='Select proteins:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(
                width='300px',
                height='100px',
            )
        )

        # Select between relative and absolute plots
        self.metric_type = widgets.RadioButtons(
            description='Relative or Absolute:',
            options=['Relative', 'Absolute'],
            value='Relative',  # Default selection
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px', height='auto'),
            disabled=False,
            indent=True  # Keeps options aligned with description instead of appearing below
        )

        self.plot_minor_proteins = widgets.Checkbox(
            description='Show Minor Proteins',
            value=True,
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='220px')
        )

        # Checkbox plus an explanatory help icon on one row
        help_tooltip = "Groups all unselected proteins into a single 'Minor Proteins' category in the plot"
        minor_proteins_help = create_help_icon(help_tooltip)
        self.minor_proteins_row = widgets.HBox([
            self.plot_minor_proteins,
            minor_proteins_help
        ], layout=widgets.Layout(align_items='center'))

        # Fill the selector right away in case protein data is already loaded
        self._populate_protein_selector()

        self.protein_selection_box = widgets.VBox([
            self.protein_selector
        ],
        layout=widgets.Layout(
                width='320px',
                height='200px',
            ))

        # Plot type selection
        self.plot_type = widgets.RadioButtons(
            options=['Stacked Bar Plots', 'Grouped Bar Plots', 'Pie Charts'],
            value='Stacked Bar Plots',
            description='Plot Type:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Value shown in the bar plots
        self.abs_or_count = widgets.RadioButtons(
            options=['Absorbance', 'Count'],
            value='Absorbance',
            description='Count or Abs.',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Plot type, value type and metric type side by side
        self.plot_type_row = widgets.HBox([self.plot_type,
                                           self.abs_or_count,
                                           self.metric_type],
                                           layout=widgets.Layout(width='500px')
                                          )

        # --- Overall layout ---------------------------------------------------
        self.widget_box = widgets.VBox([
            widgets.HTML("<h4>Plot Controls:</h4>"),
            self.group_select,
            self.protein_selection_box,
            self.minor_proteins_row,
            self.plot_type_row,
            self.invert_plot,
            widgets.HTML("<h4>Appearance Settings:</h4>"),
            self.xlabel_widget,
            self.ylabel_widget,
            self.legend_widget,
            self.title_widget,
            self.color_scheme,
            widgets.HTML("<h4>Actions:</h4>"),
            self.plot_button,
            self.export_button,
            self.download_plot_button,
            self.info_output,
            self.plot_output,
            self.export_output
        ])

        # Export/download stay disabled until a plot has been generated
        self.export_button.disabled = True
        self.download_plot_button.disabled = True

        def _on_input_change(change):
            # Any input edit invalidates the last generated plot
            self.state_manager.update_state()
            self.export_button.disabled = True
            self.download_plot_button.disabled = True

        # Observers for input changes
        self.group_select.observe(_on_input_change, names='value')
        self.xlabel_widget.observe(_on_input_change, names='value')
        self.ylabel_widget.observe(_on_input_change, names='value')
        self.legend_widget.observe(_on_input_change, names='value')
        self.title_widget.observe(_on_input_change, names='value')
        self.color_scheme.observe(_on_input_change, names='value')

        # Button click handlers
        self.plot_button.on_click(self._on_plot_button_click)
        self.export_button.on_click(self._on_export_button_click)
        self.download_plot_button.on_click(self._on_download_plot_click)

        # Most recently generated figure (None until the first plot)
        self.current_fig = None

        # Refresh group and protein choices whenever a new file is uploaded
        self.data_transformer.merged_uploader.observe(self._update_group_options, names='value')
        self.data_transformer.merged_uploader.observe(self._populate_protein_selector, names='value')

        # Re-draw an existing plot when the color scheme changes; registered
        # after _on_input_change so the buttons are locked first.
        self.color_scheme.observe(self._on_color_scheme_change, names='value')
  
    def _update_group_options(self, change):
        """Update group selection options when data changes"""
        if self.data_transformer.merged_df is not None:
            # Get all Avg_ columns
            avg_columns = [col.replace('Avg_', '') for col in self.data_transformer.merged_df.columns 
                         if col.startswith('Avg_')]
            
            # Update group selection options
            self.group_select.options = avg_columns
            # Select all groups by default
            self.group_select.value = avg_columns

    def _populate_protein_selector(self, change=None):
        """Populate the protein selector with proteins ordered by their relative abundance across all samples"""
        
        # Check if data_transformer is available
        if not hasattr(self, 'data_transformer') or self.data_transformer is None:
            return
            
        # Use proteins_dic (with 's') instead of protein_dic
        if not hasattr(self.data_transformer, 'proteins_dic') or not self.data_transformer.proteins_dic:
            return
            
        try:
            # Calculate protein abundance across all samples
            protein_abundance = {}
            
            if hasattr(self.data_transformer, 'merged_df') and self.data_transformer.merged_df is not None:
                df = self.data_transformer.merged_df
                
                # Find all Avg_ columns for abundance data
                abundance_cols = [col for col in df.columns if col.startswith('Avg_')]
                protein_col = 'Master Protein Accessions'
                
                if abundance_cols and protein_col in df.columns:
                    
                    # Process each row in the dataframe
                    for _, row in df.iterrows():
                        # Skip rows without protein information
                        if pd.isna(row[protein_col]) or row[protein_col] == '':
                            continue
                            
                        # Get proteins for this peptide
                        proteins = [p.strip() for p in str(row[protein_col]).split(';') if p.strip()]
                        
                        # Calculate total abundance across all samples for this peptide
                        total_abundance = 0
                        for col in abundance_cols:
                            try:
                                if pd.notna(row.get(col)):
                                    total_abundance += float(row.get(col, 0))
                            except (ValueError, TypeError) as e:
                                print(f"Error converting abundance value in column {col}: {str(e)}")
                                print(f"Value: {row.get(col)}, Type: {type(row.get(col))}")
                        
                        # If there are multiple proteins, divide the abundance equally among them
                        per_protein_abundance = total_abundance / len(proteins) if proteins else 0
                        
                        # Add to each protein's total
                        for protein in proteins:
                            if protein in protein_abundance:
                                protein_abundance[protein] += per_protein_abundance
                            else:
                                protein_abundance[protein] = per_protein_abundance
                    
            
            # Get the list of all proteins from proteins_dic
            all_proteins = list(self.data_transformer.proteins_dic.keys())
            
            # Sort proteins by abundance (highest first)
            if protein_abundance:
                # Get proteins sorted by abundance
                sorted_proteins = sorted(all_proteins, 
                                        key=lambda p: protein_abundance.get(p, 0), 
                                        reverse=True)

                # Create options with protein ID and name
                options = []
                options.append('All')  # Add 'All' option first
                
                # Add each protein with its ID, name and abundance
                for protein_id in sorted_proteins:
                    protein_info = self.data_transformer.proteins_dic.get(protein_id, {})
                    protein_name = protein_info.get('name', protein_id)
                    abundance = protein_abundance.get(protein_id, 0)

                    options.append(protein_name)

            else:
                # Fallback to alphabetical if no abundance data
                print("No abundance data available, falling back to alphabetical sorting")
                sorted_proteins = sorted(all_proteins)
                
                # Create options with protein ID and name
                options = []
                options.append('All')  # Add 'All' option first
                
                # Add each protein with its ID and name
                for protein_id in sorted_proteins:
                    protein_info = self.data_transformer.proteins_dic.get(protein_id, {})
                    protein_name = protein_info.get('name', protein_id)

                    options.append(protein_name)
            
            self.protein_selector.options = options
            
            # For SelectMultiple, value must be a tuple
            current_selection = self.protein_selector.value
            
            # If current selection is empty or invalid, default to 'All'
            valid_ids = [opt[0] for opt in options]
            if not current_selection or not all(item in valid_ids for item in current_selection):
                self.protein_selector.value = ('All',)
            
        except Exception as e:
            import traceback
            traceback.print_exc()
        
    def _get_proteins_to_plot(self):
        """Get the list of proteins to plot based on user selection"""
        try:
            # 'Select Specific Proteins'
            # Get selection from multi-select widget
            if hasattr(self, 'protein_selector'):
                selected = self.protein_selector.value
                
                # Handle 'All' selection
                if 'All' in selected:
                    # Use all proteins in the dataframe
                    if hasattr(self, 'proteins_df'):
                        self.pro_list = list(set(self.proteins_df['Description']))
                    else:
                        self.pro_list = []
                    return self.pro_list
                else:
                    # Use selected proteins
                    self.pro_list = list(selected)
                    return self.pro_list
            else:
                print("Protein selector widget not found")
                self.pro_list = []
                return []
                    
        except Exception as e:
            print(f"Error getting proteins to plot: {str(e)}")
            import traceback
            traceback.print_exc()
            
            # Set empty list on error for compatibility
            self.pro_list = []
            return []

    def _on_color_scheme_change(self, change):
        """Update plot when color scheme changes"""
        if self.current_fig is not None and hasattr(self, 'plot_button'):
            # Trigger plot update by simulating a button click
            self._on_plot_button_click(None)

    def _get_avg_columns(self):
        """Get all columns that start with 'Avg_' from the merged dataframe"""
        if self.data_transformer.merged_df is not None:
            return [col for col in self.data_transformer.merged_df.columns if col.startswith('Avg_')]
        return []

    def _update_group_options(self, change):
        """Update group selection options when data changes"""
        if self.data_transformer.merged_df is not None:
            avg_columns = self._get_avg_columns()
            # Remove 'Avg_' prefix for display
            group_options = [col.replace('Avg_', '') for col in avg_columns]
            self.group_select.options = group_options
            # Select all groups by default
            self.group_select.value = group_options

    def process_data(self, selected_groups=None):
        """Process data with optional group selection"""
        if self.data_transformer.merged_df is None or not self.data_transformer.proteins_dic:
            return False

        df = self.data_transformer.merged_df.copy()
        self.merged_df = df.copy()  # Store a reference to merged_df for later use
        
        # Get Absorbance columns based on selected groups
        if selected_groups:
            Absorbance_cols = [f'Avg_{var}' for var in selected_groups]
        else:
            Absorbance_cols = self._get_avg_columns()
            
        df['Total_Absorbance'] = df[Absorbance_cols].sum(axis=1).astype(int)
        
        # Filter out zero Absorbance entries
        result_df = df[['unique ID', 'Total_Absorbance']]
        result_df = result_df[result_df['Total_Absorbance'] == 0]
        all_zero_list = list(result_df['unique ID'])
        peptides_df = df[~df['unique ID'].isin(all_zero_list)]

        # Process protein positions and create proteins DataFrame
        additional_columns = ['Master Protein Accessions', 'unique ID']
        selected_columns = additional_columns + Absorbance_cols
        
        peptides_df.loc[:, 'Master Protein Accessions'] = peptides_df['Master Protein Accessions']
        
        temp_df = peptides_df.copy()
        temp_df.loc[:, 'Protein_ID'] = temp_df['Master Protein Accessions']
        
        # Create proteins DataFrame with selected columns
        self.proteins_df = temp_df.groupby('Protein_ID').agg(
            {**{col: 'first' for col in ['Master Protein Accessions']},
            **{col: 'sum' for col in Absorbance_cols}}
        ).reset_index()
        
        # Calculate relative Absorbance for selected groups
        for col in Absorbance_cols:
            col_sum = self.proteins_df[col].sum()
            if col_sum > 0:  # Avoid division by zero
                self.proteins_df[f'Rel_{col}'] = (self.proteins_df[col] / col_sum) * 100
            else:
                self.proteins_df[f'Rel_{col}'] = 0
            
        # Create sum DataFrame for selected groups
        self.sum_df = pd.DataFrame({
            'Sample': Absorbance_cols,
            'Total_Sum': [self.proteins_df[col].sum() for col in Absorbance_cols]
        })
        
        # Add protein descriptions
        name_list = []
        for _, row in self.proteins_df.iterrows():
            if ',' in row['Protein_ID']:
                strrow = row['Protein_ID'].split(',')
                named_combo = self._fetch_protein_names('; '.join(strrow))
            else:
                named_combo = self._fetch_protein_names(row['Protein_ID'])
            name_list.append(named_combo)
        
        # Drop the 'Protein_ID' column
        self.proteins_df = self.proteins_df.drop(columns=['Protein_ID'])    
        
        self.proteins_df['Description'] = name_list
        self.proteins_df['Description'] = self.proteins_df['Description'].astype(str).str.replace(r"['\['\]]", "", regex=True)
                                    
        # Determine counts based on merged_df and add to proteins_df
        if selected_groups and self.proteins_df is not None and df is not None:
            # Add count columns to the proteins_df (initialize with zeros)
            for group in selected_groups:
                count_col = f'Count_{group}'
                self.proteins_df[count_col] = 0
            
            # Create a mapping from accession to protein index in proteins_df
            accession_to_idx = {}
            for idx, row in self.proteins_df.iterrows():
                if 'Master Protein Accessions' in row and pd.notna(row['Master Protein Accessions']):
                    accession_to_idx[row['Master Protein Accessions']] = idx
                elif 'Accession' in row and pd.notna(row['Accession']):
                    accession_to_idx[row['Accession']] = idx
            
            # For each group, count peptides per protein
            for group in selected_groups:
                # Filter peptides that are present in this group
                group_peptides = df[df[f'Avg_{group}'] > 0]
                
                # Track which peptides have already been counted
                counted_peptides = set()
                
                # Track warning stats
                peptides_with_no_accession = 0
                peptides_with_no_id = 0
                peptides_already_counted = 0
                peptides_with_multi_accessions = set()
                peptides_with_no_protein_match = 0
                
                # Count peptides for each protein
                for _, peptide in group_peptides.iterrows():
                    if 'Master Protein Accessions' not in peptide or pd.isna(peptide['Master Protein Accessions']):
                        peptides_with_no_accession += 1
                        continue
                        
                    # Get unique peptide ID to track counting
                    peptide_id = peptide.get('unique ID', None)
                    if peptide_id is None or pd.isna(peptide_id):
                        peptides_with_no_id += 1
                        continue  # Skip if no unique ID
                    
                    # Skip if we've already counted this peptide
                    if peptide_id in counted_peptides:
                        peptides_already_counted += 1
                        continue
                    
                    accession = peptide['Master Protein Accessions']
                    found_match = False
                    proteins_with_multi_accessions = []
                    # Check if this peptide maps to multiple proteins
                    if ';' in accession:
                        proteins_with_multi_accessions.add(peptide_id)
                        accessions = [acc.strip() for acc in accession.split(';') if acc.strip()]
                        
                        # Only count for the first valid protein in the list
                        for acc in accessions:
                            if acc in accession_to_idx:
                                idx = accession_to_idx[acc]
                                count_col = f'Count_{group}'
                                self.proteins_df.at[idx, count_col] += 1
                                counted_peptides.add(peptide_id)  # Mark as counted
                                found_match = True
                                break  # Count only once
                    else:
                        # Handle direct match - only single protein
                        if accession in accession_to_idx:
                            idx = accession_to_idx[accession]
                            count_col = f'Count_{group}'
                            self.proteins_df.at[idx, count_col] += 1
                            counted_peptides.add(peptide_id)  # Mark as counted
                            found_match = True
                    
                    # Track peptides that didn't match any protein in our list
                    if not found_match:
                        peptides_with_no_protein_match += 1

                # Display warning about peptides mapping to multiple proteins
                warning_html = '<div style="color: orange; margin: 5px 0;"><b>Warning:</b> Peptide counting stats for group {0}:<br>'
                
                if peptides_with_no_accession > 0:
                    warning_html += f'• Skipped {peptides_with_no_accession} peptides with no accession<br>'
                    
                if peptides_with_no_id > 0:
                    warning_html += f'• Skipped {peptides_with_no_id} peptides with no unique ID<br>'
                    
                if peptides_already_counted > 0:
                    warning_html += f'• Skipped {peptides_already_counted} duplicate peptides (already counted)<br>'
                    
                if len(proteins_with_multi_accessions) > 0:
                    warning_html += f'• Found {len(proteins_with_multi_accessions)} peptides mapping to multiple proteins<br>'
                    warning_html += f'  (Each counted only once for the first matching protein)<br>'
                    
                if peptides_with_no_protein_match > 0:
                    warning_html += f'• {peptides_with_no_protein_match} peptides had no matching protein in the protein list<br>'
                    
                total_peptides = len(group_peptides)
                warning_html += f'• Total peptides processed: {total_peptides}, successfully counted: {len(counted_peptides)}'
                warning_html += '</div>'
                
                # Only display if there's something to report
                if (peptides_with_no_accession > 0 or peptides_with_no_id > 0 or 
                    peptides_already_counted > 0 or len(proteins_with_multi_accessions) > 0 or
                    peptides_with_no_protein_match > 0):
                    display(HTML(warning_html.format(group)))


            # Calculate relative count columns
            for group in selected_groups:
                count_col = f'Count_{group}'
                rel_count_col = f'Rel_Count_{group}'
                
                total_count = self.proteins_df[count_col].sum()
                if total_count > 0:
                    self.proteins_df[rel_count_col] = (self.proteins_df[count_col] / total_count * 100).round(2)
                else:
                    self.proteins_df[rel_count_col] = 0
        
        # Calculate average absorbance for sorting using all available groups
        if selected_groups:
            selected_avg_columns = [f'Avg_{var}' for var in selected_groups]
            
            # Calculate sum of all selected columns
            total_sum = self.proteins_df[selected_avg_columns].sum().sum()
            
            # Calculate row sums
            row_sums = self.proteins_df[selected_avg_columns].sum(axis=1)
            
            # Calculate relative percentage contribution
            self.proteins_df['avg_absorbance_all'] = (row_sums / total_sum * 100).round(2)
            
            # Sort proteins by abundance for consistent ordering
            self.proteins_df = self.proteins_df.sort_values('avg_absorbance_all', ascending=False)
            
  
            # For Count
            for group in selected_groups:
                count_col = f'Count_{group}'
                rel_count_col = f'Rel_Count_{group}'
                self.proteins_df[f'total_count_{count_col}'] = self.proteins_df[count_col].sum()
        
        return True
   
    def _fetch_protein_names(self, accession_str):
        """
        Fetch protein names from the proteins dictionary.
        Returns a list of protein names, using the full protein name.
        """
        names = []
        for acc in accession_str.split('; '):
            if acc in self.data_transformer.proteins_dic:
                # Use the full protein name instead of splitting it
                name = self.data_transformer.proteins_dic[acc]['name']
                names.append(name)
            else:
                names.append(acc)
        return names

    def _get_color_sequence(self, n_colors):
        """Get color sequence based on selected scheme."""
        if n_colors <= 0:
            return []
        
        try:
            # Get the selected color scheme
            scheme = 'HSV'  # Default scheme
            if hasattr(self, 'color_scheme') and self.color_scheme.value:
                scheme = self.color_scheme.value
            
            # Skip header options that start with '---'
            if scheme.startswith('---'):
                scheme = 'HSV'  # Default to HSV if a header is selected
            
            # Handle special cases
            if scheme.lower() in ['rainbow', 'hsv']:
                return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
            
            # Try qualitative color scales first (best for categorical data)
            color_sequence = getattr(px.colors.qualitative, scheme, None)
            if color_sequence is None:
                # Try sequential color scales
                color_sequence = getattr(px.colors.sequential, scheme, None)
            if color_sequence is None:
                # Try diverging color scales
                color_sequence = getattr(px.colors.diverging, scheme, None)
            if color_sequence is None:
                # Try cyclical color scales
                color_sequence = getattr(px.colors.cyclical, scheme, None)
            
            if color_sequence:
                if n_colors >= len(color_sequence):
                    # If we need more colors than available, interpolate
                    indices = np.linspace(0, len(color_sequence)-1, n_colors)
                    return [color_sequence[int(i)] for i in indices]
                else:
                    # If we need fewer colors, take a subset
                    indices = np.linspace(0, len(color_sequence)-1, n_colors, dtype=int)
                    return [color_sequence[i] for i in indices]
            
            # Default to HSV if no matching scheme found
            return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
            
        except Exception as e:
            print(f"Error generating colors: {e}")
            # Fallback to HSV
            return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]

    def plot_stacked_bar_scaled(self, title, selected_groups, use_count=False):
        """
        Build a stacked bar chart of the proteins in ``self.pro_list``.

        Values come from ``self.proteins_df`` columns named
        ``Avg_<group>`` / ``Rel_Avg_<group>`` (or ``Count_<group>`` /
        ``Rel_Count_<group>`` when ``use_count`` is true). Proteins that are not
        in ``self.pro_list`` are summed into one 'Minor Proteins' row, which is
        plotted only when ``self.plot_minor_proteins.value`` is truthy.

        ``self.invert_plot.value`` selects the orientation: 'By Sample' puts the
        groups on the x-axis with one trace per protein; any other value puts
        the proteins on the x-axis with one trace per group. When
        ``self.metric_type.value`` contains 'relative', percentages are plotted
        instead of absolute values. For absolute metrics an extra text-only
        scatter trace labels each bar stack with its total.

        Args:
            title: Fallback plot title, used when ``self.title_widget`` is empty.
            selected_groups: Sequence of group names (column-name suffixes) to plot.
            use_count: Plot peptide counts instead of summed abundance.

        Returns:
            The plotly figure, or ``None`` when ``self.proteins_df`` or
            ``self.sum_df`` is ``None``.

        Side effects:
            Calls ``self.state_manager.generate_completed()`` and enables
            ``self.export_button`` and ``self.download_plot_button``.
        """
        if self.proteins_df is None or self.sum_df is None:
            return None

        scaled_df = self.proteins_df.copy()

        # Column prefixes, display name and hover number format for the chosen metric
        if use_count:
            value_prefix = "Count_"
            rel_prefix = "Rel_Count_"
            metric_name = "Peptide Count"
            abs_format = ",.0f"  # Integer format for counts
        else:
            value_prefix = "Avg_"
            rel_prefix = "Rel_Avg_"
            metric_name = "Summed Abundance"
            abs_format = ".2e"   # Scientific notation for abundance

        # Check if we're using relative metrics
        is_relative_metric = hasattr(self, 'metric_type') and 'relative' in self.metric_type.value.lower()

        orientation = self.invert_plot.value
        by_sample = orientation == 'By Sample'
        by_protein = orientation == 'By Protein'

        # Create mapping from sample name to column names
        value_cols = {var: f'{value_prefix}{var}' for var in selected_groups}
        rel_cols = {var: f'{rel_prefix}{var}' for var in selected_groups}

        # Make sure every relative column exists. This is done for absolute plots
        # too, because the hover text always reports the relative share.
        for group in selected_groups:
            value_col = value_cols[group]
            rel_col = rel_cols[group]
            if rel_col not in scaled_df.columns:
                total = scaled_df[value_col].sum()
                if total > 0:
                    scaled_df[rel_col] = scaled_df[value_col] / total * 100
                else:
                    scaled_df[rel_col] = 0

        # Total per group (for the absolute-value labels): prefer the value stored
        # in sum_df, otherwise sum the column (Count columns might not be in sum_df)
        total_sums = {}
        known_samples = self.sum_df['Sample'].values
        for group in selected_groups:
            sample_key = value_cols[group]
            if sample_key in known_samples:
                total_sum = self.sum_df.loc[self.sum_df['Sample'] == sample_key, 'Total_Sum'].values[0]
            else:
                total_sum = self.proteins_df[sample_key].sum()
            total_sums[group] = total_sum

        # Collapse every protein outside pro_list into a single 'Minor Proteins' row
        is_selected = scaled_df['Description'].isin(self.pro_list)
        minor_rows = scaled_df[~is_selected]
        minor_record = {
            'Description': 'Minor Proteins',
            'Master Protein Accessions': 'Minor Proteins',
        }
        for group in selected_groups:
            value_col = value_cols[group]
            rel_col = rel_cols[group]
            minor_value = minor_rows[value_col].sum()
            if is_relative_metric:
                # When using relative metrics, calculate the percentage directly
                total_value = scaled_df[value_col].sum()
                minor_rel_value = minor_value / total_value * 100 if total_value > 0 else 0
            else:
                # For absolute metrics, sum the (pre-calculated or derived) relative column
                minor_rel_value = minor_rows[rel_col].sum()
            minor_record[rel_col] = minor_rel_value
            minor_record[value_col] = minor_value
        # With no groups there is nothing to aggregate, so no minor row is produced
        minor_proteins_df = pd.DataFrame([minor_record]) if len(selected_groups) else pd.DataFrame()

        # Keep only proteins in pro_list; copy so adding 'Order' does not write to a slice
        scaled_df = scaled_df[is_selected].copy()

        # Sort rows to follow the order of pro_list
        description_order = {desc: i for i, desc in enumerate(self.pro_list)}
        scaled_df['Order'] = scaled_df['Description'].map(description_order)
        scaled_df = scaled_df.sort_values(by='Order').reset_index(drop=True)

        # Only append the minor proteins row if the checkbox is checked
        if self.plot_minor_proteins.value:
            scaled_df = pd.concat([scaled_df, minor_proteins_df], ignore_index=True)

        has_minor_row = 'Minor Proteins' in scaled_df['Description'].values
        x_proteins = list(self.pro_list) + (['Minor Proteins'] if has_minor_row else [])

        # First row per description, used by the protein-on-x-axis orientation
        first_row_by_description = {}
        if not by_sample:
            for _, row in scaled_df.iterrows():
                first_row_by_description.setdefault(row['Description'], row)

        # Create figure object before adding traces
        fig = go.Figure()

        if by_sample:
            # One trace per protein, colored by its position in pro_list
            palette = self._get_color_sequence(len(self.pro_list))
            color_by_protein = {}
            for desc, color in zip(self.pro_list, palette):
                color_by_protein.setdefault(desc, color)

            for _, row in scaled_df.iterrows():
                protein_description = row['Description']
                if protein_description == 'Minor Proteins':
                    color = '#808080'  # Gray color for Minor Proteins
                else:
                    color = color_by_protein[protein_description]

                values_to_plot = []
                hover_texts = []
                for group in selected_groups:
                    abs_value = row[value_cols[group]]
                    rel_value = row[rel_cols[group]]

                    # Zero values are kept so every trace stays aligned with the x-axis
                    values_to_plot.append(rel_value if is_relative_metric else abs_value)
                    hover_texts.append(
                        f"Protein: {row['Master Protein Accessions']}<br>"
                        f"Description: {row['Description']}<br>"
                        f"Relative {metric_name}: {rel_value:.2f}%<br>"
                        f"Absolute {metric_name}: {abs_value:{abs_format}}<br>"
                    )

                # Add trace for this protein
                fig.add_trace(go.Bar(
                    name=protein_description,
                    x=list(selected_groups),
                    y=values_to_plot,
                    marker_color=color,
                    hovertext=hover_texts,
                    hoverinfo='text'
                ))
        else:
            # Inverted plotting logic (proteins on x-axis): one trace per sample
            palette = self._get_color_sequence(len(selected_groups))

            for i, group in enumerate(selected_groups):
                plot_col = rel_cols[group] if is_relative_metric else value_cols[group]

                values_to_plot = []
                hover_texts = []
                for protein in x_proteins:
                    protein_row = first_row_by_description.get(protein)
                    if protein_row is None:
                        # Protein in pro_list but absent from the data: plot a zero
                        values_to_plot.append(0)
                        hover_texts.append(f"Protein: {protein}<br>Sample: {group}<br>Value: 0")
                        continue

                    abs_value = protein_row[value_cols[group]]
                    rel_value = protein_row[rel_cols[group]]

                    values_to_plot.append(protein_row[plot_col])
                    hover_texts.append(
                        f"Protein: {protein}<br>"
                        f"Sample: {group}<br>"
                        f"Relative {metric_name}: {rel_value:.2f}%<br>"
                        f"Absolute {metric_name}: {abs_value:{abs_format}}<br>"
                    )

                # Add trace for this sample
                fig.add_trace(go.Bar(
                    name=group,
                    x=x_proteins,
                    y=values_to_plot,
                    marker_color=palette[i],
                    hovertext=hover_texts,
                    hoverinfo='text'
                ))

        # Get custom labels
        x_label = self.xlabel_widget.value or ('Proteins' if by_protein else 'Samples')
        y_label = self.ylabel_widget.value or (f"Relative {metric_name} (%)" if is_relative_metric else metric_name)
        plot_title = self.title_widget.value or title or f'Protein Distribution Analysis By {metric_name}'
        legend_title = self.legend_widget.value or ('Samples' if by_protein else 'Protein Origins')

        # Per-sample percentages stack to at most 100 only when samples are on the
        # x-axis. With proteins on the x-axis the percentages of several samples are
        # stacked and can exceed 100, so the axis is left to auto-range there.
        y_range = [0, 100] if (is_relative_metric and by_sample) else None

        fig.update_layout(
            barmode='stack',
            title={
                'text': plot_title,
                'y': 0.95,
                'x': 0.5,
                'xanchor': 'center',
                'yanchor': 'top',
                'font': {"size": 18, 'color': 'black'}
            },
            xaxis_title=x_label,
            yaxis_title=y_label,
            yaxis=dict(
                showline=True,
                gridcolor='lightgray',
                showgrid=True,
                showticklabels=True,
                linewidth=1,
                linecolor='black',
                mirror=False,
                zeroline=False,  # Don't show zero line
                range=y_range
            ),
            xaxis=dict(
                showline=True,
                linewidth=1,
                linecolor='black',
                mirror=False
            ),
            legend_title=legend_title,
            legend={
                'yanchor': "top",
                'y': 0.95,
                'xanchor': "left",
                'x': 1.05,
                'traceorder': 'normal',
                'font': {"size": 16, 'color': 'black'},
                'bgcolor': 'rgba(255, 255, 255, 0.9)'
            },
            showlegend=True,
            template='plotly_white',
            height=820,
            width=1200,
            margin=dict(
                t=100,
                l=100,
                r=100,
                b=100
            ),
            hoverlabel=dict(
                bgcolor="white",
                font_size=14,
                font_family="Arial"
            ))

        # Tick labels are rotated 45 degrees for both orientations
        fig.update_xaxes(
            tickangle=45,
            title_font={"size": 18},
            tickfont={"size": 16},
            tickfont_color="black",  # Black tick labels
            title_font_color="black",  # Black axis title
        )

        # Update Y axis formatting based on metric
        if is_relative_metric:
            tick_format = ".1f"  # One decimal place for percentages
        elif use_count:
            tick_format = ""  # Regular integers for counts
        else:
            tick_format = ".1e"  # Scientific notation for abundance

        fig.update_yaxes(
            title_font={"size": 18},
            tickfont={"size": 16},
            tickfont_color="black",  # Black tick labels
            title_font_color="black",  # Black axis title
            gridcolor="lightgray",  # Light gray grid lines
            showgrid=True,  # Show grid lines
            zeroline=False,  # Hide zero line
            exponentformat='E',
            showexponent='all',
            tickformat=tick_format
        )

        # Text-only trace with the total above each bar stack (absolute metrics only)
        if not is_relative_metric:
            if by_sample:
                # Sample-wise totals
                totals_x = selected_groups
                totals_y = [total_sums[group] for group in selected_groups]
            else:
                # Protein-wise totals across all selected groups
                protein_sums = {}
                for protein in x_proteins:
                    protein_row = first_row_by_description.get(protein)
                    total = 0
                    if protein_row is not None:
                        for group in selected_groups:
                            total += protein_row[value_cols[group]]
                    protein_sums[protein] = total
                totals_x = list(protein_sums.keys())
                totals_y = list(protein_sums.values())

            # Format for display based on metric
            if use_count:
                totals_text = [f"{int(value)}" for value in totals_y]
            else:
                totals_text = [f"{value:.2e}" for value in totals_y]

            fig.add_trace(go.Scatter(
                x=totals_x,
                y=totals_y,
                mode='text',
                text=totals_text,
                textposition='top center',
                textfont=dict(size=12, color='black'),
                showlegend=True,
                name=f'Show Total {metric_name}',
                hoverinfo='none',
                texttemplate='%{text}'
            ))

        # Mark generation as complete
        self.state_manager.generate_completed()
        self.export_button.disabled = False
        self.download_plot_button.disabled = False

        return fig
    
    def create_pie_charts(self, selected_groups, use_count=False):
        """Create pie charts for protein data with pre-calculated counts or abundance"""
        try:
            # Import necessary libraries
            import plotly.graph_objects as go
            from plotly.subplots import make_subplots
            import pandas as pd
            
            # Check if we have data
            if not hasattr(self, 'proteins_df') or self.proteins_df is None or self.proteins_df.empty:
                print("No protein data available to plot")
                return None
            
            # Determine the metric name and column prefixes based on the selected metric
            if use_count:
                metric_name = "Peptide Count"
                value_prefix = "Count_"
                rel_value_prefix = "Rel_Count_"
                num_format = ",.0f"  # Integer format for counts
            else:
                metric_name = "Abundance"
                value_prefix = "Avg_"
                rel_value_prefix = "Rel_Avg_"
                num_format = ",.2e"  # Scientific notation for abundance

            # Create a working copy of the dataframe
            plot_df = self.proteins_df.copy()
            
            # Get columns to use for plotting based on metric
            value_cols = [f'{value_prefix}{group}' for group in selected_groups]
            rel_value_cols = [f'{rel_value_prefix}{group}' for group in selected_groups]
            
            # Check if required columns exist
            missing_value_cols = [col for col in value_cols if col not in plot_df.columns]
            missing_rel_cols = [col for col in rel_value_cols if col not in plot_df.columns]
            
            if missing_value_cols:
                print(f"Warning: Missing {metric_name.lower()} columns: {missing_value_cols}")
                if use_count:
                    print("Defaulting to abundance columns")
                    use_count = False
                    value_prefix = "Avg_"
                    rel_value_prefix = "Rel_Avg_"
                    metric_name = "Abundance"
                    value_cols = [f'{value_prefix}{group}' for group in selected_groups]
                    rel_value_cols = [f'{rel_value_prefix}{group}' for group in selected_groups]
                else:
                    print("Cannot generate plot with missing data columns")
                    return None
            
            # Regenerate relative columns if they don't exist
            if missing_rel_cols:
                print(f"Calculating missing relative {metric_name.lower()} columns")
                for group in selected_groups:
                    col = f'{value_prefix}{group}'
                    rel_col = f'{rel_value_prefix}{group}'
                    
                    if col in plot_df.columns and rel_col not in plot_df.columns:
                        total = plot_df[col].sum()
                        if total > 0:
                            plot_df[rel_col] = (plot_df[col] / total * 100).round(2)
                        else:
                            plot_df[rel_col] = 0
            
            # Ensure we have plot columns
            if not value_cols:
                print("No data columns found for selected groups")
                return None
                
            # Handle Minor Proteins if needed
            if hasattr(self, 'plot_minor_proteins') and self.plot_minor_proteins.value:
                # Check if minor proteins already exist in the dataframe
                if 'Minor Proteins' not in plot_df['Description'].values:
                    # Calculate minor proteins data
                    if hasattr(self, 'pro_list') and self.pro_list:
                        # Get proteins that are not in pro_list
                        non_selected = plot_df[~plot_df['Description'].isin(self.pro_list)]
                        
                        if not non_selected.empty:
                            # Create a minor proteins row with sums for each group
                            minor_row = {'Description': 'Minor Proteins', 'Master Protein Accessions': 'Minor Proteins'}
                            
                            # Calculate sums for values and relative columns
                            for col in value_cols:
                                minor_row[col] = non_selected[col].sum()
                            
                            for col in rel_value_cols:
                                if col in plot_df.columns:
                                    minor_row[col] = non_selected[col].sum()
                            
                            # Add row to dataframe
                            plot_df = pd.concat([plot_df, pd.DataFrame([minor_row])], ignore_index=True)
            
            # Filter to only selected proteins (including minor proteins if enabled)
            if hasattr(self, 'pro_list') and self.pro_list:
                if 'Minor Proteins' in plot_df['Description'].values and hasattr(self, 'plot_minor_proteins') and self.plot_minor_proteins.value:
                    # Include both selected proteins and Minor Proteins
                    plot_df = plot_df[
                        plot_df['Description'].isin(self.pro_list) | 
                        (plot_df['Description'] == 'Minor Proteins')
                    ]
                else:
                    # Only include selected proteins
                    plot_df = plot_df[plot_df['Description'].isin(self.pro_list)]
            
            # Ensure we have data after filtering
            if plot_df.empty:
                print("No data available after filtering")
                return None
            
            # Determine orientation - sync with invert_plot for consistency
            if hasattr(self, 'invert_plot'):
                orientation = self.invert_plot.value
            else:
                orientation = 'By Sample'  # Default
            
            
            # Create figure and subplots with maximum 3 columns
            if orientation == 'By Sample':
                # One pie chart per sample
                num_samples = len(value_cols)
                num_cols = min(3, num_samples)  # Maximum 3 columns
                num_rows = (num_samples + num_cols - 1) // num_cols  # Ceiling division
                
                # Create subplot titles
                subplot_titles = [col.replace(value_prefix, '') for col in value_cols]
                
                # Create figure with grid layout
                fig = make_subplots(
                    rows=num_rows, 
                    cols=num_cols,
                    specs=[[{'type': 'pie'} for _ in range(num_cols)] for _ in range(num_rows)],
                    subplot_titles=subplot_titles
                )
                
                # Get unique proteins for coloring (excluding Minor Proteins)
                major_proteins = [p for p in plot_df['Description'].unique() if p != 'Minor Proteins']
                
                # Use the existing color sequence function for major proteins
                protein_colors = self._get_color_sequence(len(major_proteins))
                
                # Create a color map, setting Minor Proteins to grey
                color_map = {protein: color for protein, color in zip(major_proteins, protein_colors)}
                if 'Minor Proteins' in plot_df['Description'].values:
                    color_map['Minor Proteins'] = '#808080'  # Grey color for minor proteins
                
                # First pie chart will set the legend for all
                first_chart = True
                
                # Create a pie chart for each sample
                for i, col in enumerate(value_cols):
                    # Calculate which row and column this chart belongs in
                    row_idx = i // num_cols + 1  # 1-based indexing for plotly
                    col_idx = i % num_cols + 1   # 1-based indexing for plotly
                    
                    rel_col = rel_value_cols[i]
                    group_name = col.replace(value_prefix, '')
                    
                    sample_data = plot_df[['Description', col, rel_col]].copy()
                    # Filter out zero values
                    sample_data = sample_data[sample_data[col] > 0]
                    
                    # Skip if no data
                    if sample_data.empty:
                        continue
                    
                    # Sort by value but ensure Minor Proteins is at the end
                    if 'Minor Proteins' in sample_data['Description'].values:
                        # Extract Minor Proteins row
                        minor_row = sample_data[sample_data['Description'] == 'Minor Proteins']
                        # Get the rest sorted by value
                        other_rows = sample_data[sample_data['Description'] != 'Minor Proteins']
                        other_rows = other_rows.sort_values(by=col, ascending=False)
                        # Combine with Minor Proteins at the end
                        sample_data = pd.concat([other_rows, minor_row], ignore_index=True)
                    else:
                        sample_data = sample_data.sort_values(by=col, ascending=False)
                    
                    # Get colors for the current sample's proteins
                    colors = [color_map[protein] for protein in sample_data['Description']]
                    
                    # Create the pie chart
                    fig.add_trace(
                        go.Pie(
                            labels=sample_data['Description'],
                            values=sample_data[col],
                            name=group_name,
                            marker_colors=colors,
                            textposition='inside',
                            textinfo='percent',
                            hovertemplate=(
                                f"Protein: %{{label}}<br>"
                                 f"{metric_name}: %{{value:{num_format}}}<br>"  # Using num_format variable
                                f"Percentage: %{{percent}}<br>"
                                f"<extra></extra>"
                            ),
                            hole=0.3,
                            showlegend=first_chart  # Only show legend for the first chart
                        ),
                        row=row_idx, col=col_idx
                    )
                    
                    # After first chart, don't show labels in legend again
                    first_chart = False
            
            else:  # 'By Protein'
                # One pie chart for each protein showing distribution across samples
                unique_proteins = plot_df['Description'].unique()
                num_proteins = len(unique_proteins)
                
                num_cols = min(3, num_proteins)  # Maximum 3 columns
                num_rows = (num_proteins + num_cols - 1) // num_cols  # Ceiling division
                
                # Create figure with grid layout
                fig = make_subplots(
                    rows=num_rows, 
                    cols=num_cols,
                    specs=[[{'type': 'pie'} for _ in range(num_cols)] for _ in range(num_rows)],
                    subplot_titles=unique_proteins
                )
                
                # Use the existing color sequence function for samples
                sample_colors = self._get_color_sequence(len(value_cols))
                
                # Create a color map for samples
                color_map = {col.replace(value_prefix, ''): color for col, color in zip(value_cols, sample_colors)}
                
                # First pie chart will set the legend for all
                first_chart = True
                
                # Create a pie chart for each protein
                for i, protein in enumerate(unique_proteins):
                    # Calculate which row and column this chart belongs in
                    row_idx = i // num_cols + 1  # 1-based indexing for plotly
                    col_idx = i % num_cols + 1   # 1-based indexing for plotly
                    
                    protein_data = plot_df[plot_df['Description'] == protein]
                    
                    if not protein_data.empty:
                        # Get values for each sample
                        values = [protein_data[col].values[0] for col in value_cols]
                        labels = [col.replace(value_prefix, '') for col in value_cols]
                        
                        # Filter out zero values
                        non_zero_indices = [j for j, val in enumerate(values) if val > 0]
                        values = [values[j] for j in non_zero_indices]
                        labels = [labels[j] for j in non_zero_indices]
                        
                        if not values:  # Skip if no non-zero values
                            continue
                        
                        # Get colors for the current protein's samples
                        colors = [color_map[label] for label in labels]
                        
                        # Create the pie chart
                        fig.add_trace(
                            go.Pie(
                                labels=labels,
                                values=values,
                                name=protein,
                                marker_colors=colors,
                                textposition='inside',
                                textinfo='percent',
                                hovertemplate=(
                                    f"Sample: %{{label}}<br>"
                                    f"{metric_name}: %{{value:{num_format}}}<br>"  # Using num_format variable
                                    f"Percentage: %{{percent}}<br>"
                                    f"<extra></extra>"
                                ),
                                hole=0.3,
                                showlegend=first_chart  # Only show legend for the first chart
                            ),
                            row=row_idx, col=col_idx
                        )
                        
                        # After first chart, don't show labels in legend again
                        first_chart = False
            
            # Update layout with adjusted height for multiple rows
            plot_title = self.title_widget.value or f"Protein Distribution - {orientation} ({metric_name})"
            legend_title = self.legend_widget.value or ('Samples' if self.invert_plot.value == 'By Protein' else 'Protein Origins')

            
            fig.update_layout(
                height=500 * num_rows,  # Scale height based on number of rows
                width=min(1400, 450 * num_cols),  # Adjust width for maximum 3 columns
                title_text=plot_title,
                title={
                    'y': 0.98,
                    'x': 0.5,
                    'xanchor': 'center',
                    'yanchor': 'top',
                    'font': {"size": 20, 'color': 'black'}
                },
                showlegend=True,  # Show the legend
                legend={
                    'title': legend_title,
                    'yanchor': "top",
                    'y': 0.99,
                    'xanchor': "left",
                    'x': 1.02,
                    'font': {"size": 12},
                    #'bgcolor': 'rgba(255, 255, 255, 0.8)',
                    #'bordercolor': 'rgba(0, 0, 0, 0.5)',
                    #'borderwidth': 1
                },
                margin=dict(t=100, b=50, l=50, r=150),  # Increased right margin for legend
                paper_bgcolor='rgba(255,255,255,1)',
                plot_bgcolor='rgba(255,255,255,1)',
                font=dict(
                    family="Arial, sans-serif",
                    size=14,
                    color="black"
                )
            )
            
            return fig
            
        except Exception as e:
            print(f"Error creating pie charts: {str(e)}")
            import traceback
            traceback.print_exc()
            return None   

    def create_grouped_bar_plot(self, title, selected_groups, use_count=False):
        """Generate interactive Plotly grouped bar plots for proteins
        
        Similar to the bioactive peptides plotting function but adapted for protein data
        
        Args:
            title: Title for the plot
            selected_groups: Groups to include in the plot
            use_count: Whether to use counts instead of values
            
        Returns:
            Plotly figure object
        """
        if self.proteins_df is None or len(self.proteins_df) == 0:
            print("No data available for plotting.")
            return None
        
        try:
            # Check plot orientation
            # Determine orientation - sync with invert_plot for consistency
            if hasattr(self, 'invert_plot'):
                orientation = self.invert_plot.value
            else:
                orientation = 'By Sample'  # Default
            
            # Create a working copy of the dataframe to avoid modifying the original
            plot_df = self.proteins_df.copy()
            
            # Handle Minor Proteins if enabled
            if hasattr(self, 'plot_minor_proteins') and self.plot_minor_proteins.value:
                # Check if minor proteins already exist in the dataframe
                if 'Minor Proteins' not in plot_df['Description'].values:
                    # Calculate minor proteins data
                    if hasattr(self, 'pro_list') and self.pro_list:
                        # Get proteins that are not in pro_list
                        non_selected = plot_df[~plot_df['Description'].isin(self.pro_list)]
                        
                        if not non_selected.empty:
                            # Create a minor proteins row with sums for each group
                            minor_row = {'Description': 'Minor Proteins', 'Master Protein Accessions': 'Minor Proteins'}
                            
                            # Calculate sums for all Average and Count columns
                            avg_cols = [col for col in plot_df.columns if col.startswith('Avg_')]
                            count_cols = [col for col in plot_df.columns if col.startswith('Count_')]
                            rel_avg_cols = [col for col in plot_df.columns if col.startswith('Rel_Avg_')]
                            rel_count_cols = [col for col in plot_df.columns if col.startswith('Rel_Count_')]
                            
                            # Sum all value columns
                            for col in avg_cols + count_cols + rel_avg_cols + rel_count_cols:
                                if col in plot_df.columns:
                                    minor_row[col] = non_selected[col].sum()
                            
                            # Add row to dataframe
                            plot_df = pd.concat([plot_df, pd.DataFrame([minor_row])], ignore_index=True)
            
            # Use the proteins from self.pro_list, plus Minor Proteins if enabled
            if hasattr(self, 'pro_list') and self.pro_list:
                proteins = list(self.pro_list)  # Create a copy of the list
                
                # Add Minor Proteins if it exists and is enabled
                if 'Minor Proteins' in plot_df['Description'].values and hasattr(self, 'plot_minor_proteins') and self.plot_minor_proteins.value:
                    proteins.append('Minor Proteins')
            else:
                # Use top proteins by threshold if available
                proteins_to_show = int(self.minor_proteins_input.value) if hasattr(self, 'minor_proteins_input') else 10
                proteins = plot_df['Description'].tolist()[:proteins_to_show]
            
            # Ensure we only have proteins that exist in our dataframe
            proteins = [p for p in proteins if p in plot_df['Description'].values]
            
            if not selected_groups or not proteins:
                print("No valid groups or proteins selected for plotting.")
                return None
                
            # Create figure
            fig = go.Figure()
            
            # Based on orientation, determine which will be the categories and which will be the bars
            if orientation == 'By Protein':
                # By Protein: Proteins on x-axis, samples as different colored bars
                categories = proteins
                bar_groups = selected_groups
                
                # Create consistent color mapping for samples
                color_sequence = self._get_color_sequence(len(selected_groups)) if hasattr(self, '_get_color_sequence') else [
                    f'hsl({i * 360 / len(selected_groups)},70%,60%)' for i in range(len(selected_groups))
                ]
                color_mapping = {group: color_sequence[i] for i, group in enumerate(selected_groups)}

            elif orientation == 'By Sample':
                # By Sample: Samples on x-axis, proteins as different colored bars
                categories = selected_groups
                bar_groups = proteins
                color_sequence = self._get_color_sequence(len(proteins)) if hasattr(self, '_get_color_sequence') else [
                    f'hsl({i * 360 / len(proteins)},70%,60%)' for i in range(len(proteins))
                ]
                color_mapping = {protein: color_sequence[i] for i, protein in enumerate(proteins)}
            
            if 'Minor Proteins' in plot_df['Description'].values:
                    color_mapping['Minor Proteins'] = '#808080'  # Grey color for minor proteins
            
            # Calculate bar positions
            n_bar_groups = len(bar_groups)
            bar_width = 0.8 / n_bar_groups  # Adjust total width of group
            
            # Get metric information
            metric = self.metric_type.value if hasattr(self, 'metric_type') else 'absolute'
            use_relative = 'relative' in metric.lower()
            
            # Calculate total values for relative plots
            total_values = {}
            if use_relative:
                if orientation == 'By Protein':
                    # By Protein: Calculate totals for each protein
                    for protein in proteins:
                        protein_row = plot_df[plot_df['Description'] == protein]
                        if protein_row.empty:
                            total_values[protein] = 0
                            continue
                        protein_row = protein_row.iloc[0]
                        if use_count:
                            total_values[protein] = sum(protein_row[f'Count_{group}'] for group in selected_groups 
                                                    if f'Count_{group}' in protein_row)
                        else:
                            total_values[protein] = sum(protein_row[f'Avg_{group}'] for group in selected_groups 
                                                    if f'Avg_{group}' in protein_row)
                elif orientation == 'By Sample':
                    # By Sample: Calculate totals for each sample
                    for group in selected_groups:
                        if use_count:
                            # Calculate total count for this group using selected proteins
                            count_col = f'Count_{group}'
                            if count_col in plot_df.columns:
                                total_values[group] = plot_df[plot_df['Description'].isin(proteins)][count_col].sum()
                            else:
                                total_values[group] = 0
                        else:
                            # Calculate total abundance for this group using selected proteins
                            avg_col = f'Avg_{group}'
                            if avg_col in plot_df.columns:
                                total_values[group] = plot_df[plot_df['Description'].isin(proteins)][avg_col].sum()
                            else:
                                total_values[group] = 0
            
            for idx, bar_group in enumerate(bar_groups):
                # Calculate x positions for this group's bars
                x_positions = [i + (idx - n_bar_groups/2 + 0.5) * bar_width for i in range(len(categories))]
                
                values = []
                hover_text = []
                
                for i, category in enumerate(categories):
                    if orientation == 'By Protein':
                        # By Protein: category = protein, bar_group = sample
                        protein = category
                        group = bar_group
                    elif orientation == 'By Sample':
                        # By Sample: category = sample, bar_group = protein
                        protein = bar_group
                        group = category
                    
                    # Get the row for this protein
                    protein_row = plot_df[plot_df['Description'] == protein]
                    if protein_row.empty:
                        # Skip if protein not found
                        values.append(0)
                        hover_text.append(f"Protein: {protein}<br>Sample: {group}<br>Value: 0")
                        continue
                        
                    protein_row = protein_row.iloc[0]
                    
                    if use_count:
                        # Get count values
                        count_col = f'Count_{group}'
                        if count_col in protein_row:
                            value = protein_row[count_col]
                            
                            if use_relative:
                                # Calculate relative counts
                                if orientation == 'By Protein':
                                    total = total_values.get(protein, 0)
                                elif orientation == 'By Sample':
                                    total = total_values.get(group, 0)
                                relative = (value / total * 100) if total > 0 else 0
                                value = relative
                                
                                # Special hover text for Minor Proteins
                                if protein == 'Minor Proteins':
                                    hover = (f"Minor Proteins<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Relative Count: {value:.1f}%<br>" +
                                            f"(Count: {int(protein_row[count_col])})")
                                else:
                                    hover = (f"Protein: {protein}<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Relative Count: {value:.1f}%<br>" +
                                            f"(Count: {int(protein_row[count_col])})")
                            else:
                                # Special hover text for Minor Proteins
                                if protein == 'Minor Proteins':
                                    hover = (f"Minor Proteins<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Count: {int(value)}")
                                else:
                                    hover = (f"Protein: {protein}<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Count: {int(value)}")
                        else:
                            value = 0
                            hover = (f"Protein: {protein}<br>" +
                                    f"Sample: {group}<br>" +
                                    f"Count: 0")
                    else:
                        # Get abundance values
                        avg_col = f'Avg_{group}'
                        if avg_col in protein_row:
                            value = protein_row[avg_col]
                            
                            if use_relative:
                                # Calculate relative abundances
                                if orientation == 'By Protein':
                                    total = total_values.get(protein, 0)
                                elif orientation == 'By Sample':
                                    total = total_values.get(group, 0)
                                relative = (value / total * 100) if total > 0 else 0
                                value = relative
                                
                                # Special hover text for Minor Proteins
                                if protein == 'Minor Proteins':
                                    hover = (f"Minor Proteins<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Relative Abundance: {value:.1f}%<br>" +
                                            f"(Abundance: {protein_row[avg_col]:.2e})")
                                else:
                                    hover = (f"Protein: {protein}<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Relative Abundance: {value:.1f}%<br>" +
                                            f"(Abundance: {protein_row[avg_col]:.2e})")
                            else:
                                # Special hover text for Minor Proteins
                                if protein == 'Minor Proteins':
                                    hover = (f"Minor Proteins<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Absolute Abundance: {value:.2e}")
                                else:
                                    hover = (f"Protein: {protein}<br>" +
                                            f"Sample: {group}<br>" +
                                            f"Absolute Abundance: {value:.2e}")
                        else:
                            value = 0
                            hover = (f"Protein: {protein}<br>" +
                                    f"Sample: {group}<br>" +
                                    f"Abundance: 0")
                    
                    values.append(value)
                    hover_text.append(hover)
                
                # Only add trace if we have valid values
                if any(values):
                    fig.add_trace(go.Bar(
                        name=bar_group,
                        x=x_positions,
                        y=values,
                        width=bar_width * 0.9,  # Slight gap between bars
                        marker_color=color_mapping.get(bar_group, 'gray'),
                        hovertext=hover_text,
                        hoverinfo='text'
                    ))
            
            # Update layout
            plot_type_mapping = {
                (False, False): 'Absolute Abundance',
                (False, True): 'Relative Abundance',
                (True, False): 'Absolute Count',
                (True, True): 'Relative Count'
            }
            plot_type = plot_type_mapping[(use_count, use_relative)]
            
            yaxis_title = {
                'Absolute Abundance': 'Absolute Abundance',
                'Relative Abundance': 'Relative Abundance (%)',
                'Absolute Count': 'Peptide Count',
                'Relative Count': 'Relative Count (%)'
            }[plot_type]
            
            # Determine x-axis title based on orientation
            xaxis_title = 'Protein' if orientation == 'By Protein' else 'Sample Category'
            
            # Determine legend title based on orientation
            legend_title = "Sample:" if orientation == 'By Protein' else "Protein:"
            
            # Update titles from widgets if available
            if hasattr(self, 'legend_widget') and self.legend_widget.value:
                legend_title = self.legend_widget.value
            
            if hasattr(self, 'xlabel_widget') and self.xlabel_widget.value:
                xaxis_title = self.xlabel_widget.value

            if hasattr(self, 'ylabel_widget') and self.ylabel_widget.value:
                yaxis_title = self.ylabel_widget.value

            if hasattr(self, 'title_widget') and self.title_widget.value:
                title = self.title_widget.value

            fig.update_layout(
                title={
                    'text': title,
                    'y': .975,
                    'x': 0.5,
                    'xanchor': 'center',
                    'yanchor': 'top',
                    'font': {'size': 18, 'color': 'black'}
                },
                xaxis_title=xaxis_title,
                yaxis_title=yaxis_title,
                legend_title=legend_title,
                legend={'yanchor': "top", 'y': 1.0, 'xanchor': "left", 'x': 1.05, 'traceorder': 'normal', 'font': {'size': 12, 'color': 'black'}},
                showlegend=True,
                template='plotly_white',
                height=750,
                width=1100,
                margin=dict(t=100, l=100, r=200),
                hoverlabel=dict(
                    bgcolor="white",
                    font_size=12,
                    font_family="Arial"
                ),
                barmode='group',
                xaxis=dict(
                    showline=True,
                    linewidth=1,
                    linecolor='black',
                    mirror=False
                ),
                yaxis=dict(
                    showline=True,
                    linewidth=1,
                    linecolor='black',
                    mirror=False
                )
            )
            
            # Update axis properties
            fig.update_xaxes(
                ticktext=categories,
                tickvals=list(range(len(categories))),
                tickangle=45,
                title_font={"size": 18},
                tickfont={"size": 16},
                tickfont_color="black",
                title_font_color="black",
            )
            
            # Set y-axis format based on plot type
            if plot_type == 'Absolute Abundance':
                fig.update_yaxes(
                    type='log',
                    exponentformat='E',
                    showexponent='all',
                    title_font={"size": 18},
                    tickfont={"size": 16},
                    tickfont_color="black",
                    title_font_color="black",
                    gridcolor="lightgray",
                    showgrid=True,
                    zeroline=False,
                )
            elif plot_type == 'Absolute Count':
                fig.update_yaxes(
                    type='linear',
                    tickformat=",d",  # Format with commas for thousands
                    title_font={"size": 18},
                    tickfont={"size": 16},
                    tickfont_color="black",
                    title_font_color="black",
                    gridcolor="lightgray",
                    showgrid=True,
                    zeroline=False,
                )
            else:  # Relative Abundance or Relative Count
                fig.update_yaxes(
                    type='linear',
                    range=[0, 100],
                    title_font={"size": 18},
                    tickfont={"size": 16},
                    tickfont_color="black",
                    title_font_color="black",
                    gridcolor="lightgray",
                    showgrid=True,
                    zeroline=False,
                )
                
            return fig
            
        except Exception as e:
            print(f"Error creating grouped bar plot: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def _on_plot_button_click(self, b):
        if self.current_fig is not None:
            self.state_manager.generate_completed()
            self.export_button.disabled = False
            self.download_plot_button.disabled = False
        with self.plot_output:
            self.plot_output.clear_output(wait=True)
            
            if not self.group_select.value:
                print("Please select at least one group to plot.")
                return
            
            selected_groups = list(self.group_select.value)
            if self.process_data(selected_groups):
                # Calculate average Absorbance for sorting using only selected groups
                selected_avg_columns = [f'Avg_{var}' for var in selected_groups]
                
                # Calculate sum of all selected columns
                total_sum = self.proteins_df[selected_avg_columns].sum().sum()
                
                # Calculate row sums
                row_sums = self.proteins_df[selected_avg_columns].sum(axis=1)
                
                # Calculate relative percentage contribution
                self.proteins_df['avg_absorbance_all'] = (row_sums / total_sum * 100).round(2)
                
                # Sort proteins by abundance for consistent ordering
                self.proteins_df = self.proteins_df.sort_values('avg_absorbance_all', ascending=False)
                
                # Get the list of proteins to plot
                self._get_proteins_to_plot()
                
                # Check if using count metric
                use_count = False
                metric_name = "Abundance"
                if hasattr(self, 'abs_or_count') and self.abs_or_count.value == 'Count':
                    use_count = True
                    metric_name = "Peptide Count"
                    
                    # Add minor proteins count if needed
                    if hasattr(self, 'plot_minor_proteins') and self.plot_minor_proteins.value:
                        
                        # Get non-selected proteins
                        non_selected_proteins = self.proteins_df[~self.proteins_df['Description'].isin(self.pro_list)]
                        
                        # Add count for minor proteins if there are any
                        if not non_selected_proteins.empty:
                            for group in selected_groups:
                                count_col = f'Count_{group}'
                                minor_count = non_selected_proteins[count_col].sum()
                                
                                # If minor proteins row doesn't exist, create it
                                if 'Minor Proteins' not in self.proteins_df['Description'].values:
                                    minor_row = pd.DataFrame({
                                        'Description': ['Minor Proteins'],
                                        'Master Protein Accessions': ['Minor Proteins'],
                                        count_col: [minor_count]
                                    })
                                    self.proteins_df = pd.concat([self.proteins_df, minor_row], ignore_index=True)
                                else:
                                    # Update existing minor proteins row
                                    minor_idx = self.proteins_df[self.proteins_df['Description'] == 'Minor Proteins'].index[0]
                                    self.proteins_df.at[minor_idx, count_col] = minor_count
                
                # Get the selected plot type (Bar or Pie)
                if hasattr(self, 'plot_type'):
                    plot_type = self.plot_type.value
                else:
                    plot_type = 'Stacked Bar Plots'  # Default to bar plot if widget doesn't exist
                
                # Create and display the appropriate plot based on selection and metric
                if plot_type == 'Stacked Bar Plots':
                    # Create and store bar plot (existing functionality)
                    title = f'Protein Distribution Analysis ({metric_name})'
                    
                    # Modified plot_stacked_bar_scaled to use count columns if needed
                    self.current_fig = self.plot_stacked_bar_scaled(
                        title=title,
                        selected_groups=selected_groups,
                        use_count=use_count
                    )
                    if self.current_fig is not None:
                        self.current_fig.show()
                elif plot_type == "Grouped Bar Plots":
                    # Get the group column - you might want to add a dropdown for this
                    title = f'Protein Distribution Analysis ({metric_name})'
                    self.current_fig = self.create_grouped_bar_plot(
                            title=title,
                            selected_groups=selected_groups,
                            use_count=use_count
                        )
                    if self.current_fig is not None:
                        self.current_fig.show()
                elif plot_type == 'Pie Charts':  # Pie Chart
                    # Create and store pie plot
                    try:
                        # Call the pie chart function with count flag
                        self.current_fig = self.create_pie_charts(
                            selected_groups=selected_groups, 
                            use_count=use_count
                        )
                        
                        if self.current_fig is not None:
                            self.current_fig.show()
                        else:
                            print("Error: Could not generate pie charts.")
                    except Exception as e:
                        print(f"Error creating pie charts: {str(e)}")
                        import traceback
                        traceback.print_exc()
            else:
                print("Please upload all required files first.")
                print("Error creating plot. Please check your data.")
                 
    def generate_download_link(self, content, filename, filetype='text/csv'):
        """Generate a download link for any content"""
        if isinstance(content, pd.DataFrame):
            if filetype == 'text/csv':
                content = content.to_csv(index=False)
            else:
                content = content.to_csv(index=True)
        if isinstance(content, str):
            content = content.encode()
        b64 = base64.b64encode(content).decode()
        return f"""
            <a id="download_link" href="data:{filetype};base64,{b64}" 
               download="{filename}"
               style="display: none;">
                Download {filename}
            </a>
            <script>
                document.getElementById('download_link').click();
            </script>
            """
    
    def _on_download_plot_click(self, b):
        """Handle plot download button click with automatic download"""
        if self.current_fig is not None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            plot_filename = f'protein_plot_{timestamp}.html'
            
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(self.generate_download_link(
                    self.current_fig.to_html(),
                    plot_filename,
                    'text/html'
                )))
        else:
            print("Please generate a plot first.")
    
    def _on_export_button_click(self, b):
        """Handle data export with automatic download"""
        if self.proteins_df is not None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            data_filename = f'protein_absorbance_analysis_{timestamp}.csv'
            
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(self.generate_download_link(
                    self.proteins_df,
                    data_filename,
                    'text/csv'
                )))
        else:
            print("Please generate the analysis first.")

    def display(self):
        """Show the protein analysis interface by rendering ``self.widget_box``."""
        # Inside the method body `display` resolves to the module-level IPython
        # function, not to this method.
        root_widget = self.widget_box
        display(root_widget)



In [5]:

# Build the data-loading front end: create the transformer and let it
# initialize and display its upload UI.
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

# Create the protein plotter on top of the transformer.
# NOTE(review): presumably the plotter reads the uploaded data from
# `data_transformer` - confirm against ProteinPlotter.__init__.
protein_plotter = ProteinPlotter(data_transformer)

# Show the protein analysis interface (the plotter's widget box).
protein_plotter.display()

VBox(children=(HTML(value='<h4>Upload Data File:</h4>'), HBox(children=(FileUpload(value=(), accept='.csv,.txt…

Output(layout=Layout(margin='0 0 0 20px', width='300px'))

VBox(children=(HTML(value='<h4>Plot Controls:</h4>'), SelectMultiple(description='Groups:', layout=Layout(heig…