In [1]:
import pandas as pd
import numpy as np
from datetime import datetime
import json, io, base64, re, os, requests, time
import plotly.graph_objects as go
from IPython.display import display, HTML, clear_output
import plotly.express as px
import ipywidgets as widgets
from xml.etree import ElementTree
# Initialize settings
import _settings as settings

# Global variables from settings
spec_translate_list = settings.SPEC_TRANSLATE_LIST
plotly_colors = settings.plotly_colors

In [2]:
class DataTransformation:
    def __init__(self):
        self.merged_df = None
        self.proteins_dic = {}
        self.output_area = None
        self.merged_uploader = None

    def create_download_link(self, file_path, label):
        """Create a download link for a file."""
        if os.path.exists(file_path):
            # Read file content and encode it as base64
            with open(file_path, 'rb') as f:
                content = f.read()
            b64_content = base64.b64encode(content).decode('utf-8')

            # Generate the download link HTML
            return widgets.HTML(f"""
                <a download="{os.path.basename(file_path)}" 
                   href="data:application/octet-stream;base64,{b64_content}" 
                   style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                    {label}
                </a>
            """)
        else:
            # Show an error message if the file does not exist
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{file_path}" not found!
                </span>
            """)

    def setup_data_loading_ui(self):
        """Initialize and display the data loading UI."""
        # Create file upload widget
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.output_area = widgets.Output()

        # Create upload box with example link
        merged_box = widgets.HBox([
            self.merged_uploader,
            self.create_download_link("example_merged_dataframe.csv", "Example")
        ], layout=widgets.Layout(align_items='center'))

        # Create left column with upload widgets
        upload_widgets = widgets.VBox([
            widgets.HTML("<h4>Upload Data File:</h4>"),
            merged_box,
            self.output_area
        ], layout=widgets.Layout(
            width='300px',
            margin='0 20px 0 0'
        ))

        # Create container for status display
        self.status_area = widgets.Output(
            layout=widgets.Layout(
                width='300px',
                margin='0 0 0 20px'
            )
        )

        display(upload_widgets,
                self.status_area)

        # Register observer
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')

    def _validate_and_clean_data(self, df):
        """
        Validate and clean the uploaded data, preserving numeric data even if stored as strings.
        Returns tuple of (cleaned_df, warnings, errors)
        """
        warnings = []
        errors = []
        cleaned_df = df.copy()
    
        # Check required columns exist
        required_columns = [
            'Master Protein Accessions', 
            'unique ID'
        ]
        
        # Check that at least one Avg_ column exists
        avg_columns = [col for col in df.columns if col.startswith('Avg_')]
        if not avg_columns:
            errors.append("No columns starting with 'Avg_' found in the data")
            return None, warnings, errors
            
        # Add Avg_ columns to required columns
        required_columns.extend(avg_columns)
        
        missing = set(required_columns) - set(df.columns)
        if missing:
            errors.append(f"Missing required columns: {', '.join(missing)}")
            return None, warnings, errors
    
        # Separate numeric and non-numeric columns
        numeric_columns = avg_columns  # Avg_ columns should be numeric
        text_columns = ['Master Protein Accessions', 'unique ID']
    
        # Handle blank values differently for numeric vs text columns
        for column in required_columns:
            if column in numeric_columns:
                # For numeric columns, try to convert to numeric first
                try:
                    # Convert to numeric, coerce errors to NaN
                    cleaned_df[column] = pd.to_numeric(cleaned_df[column], errors='coerce')
                    blank_count = cleaned_df[column].isna().sum()
                    if blank_count > 0:
                        warnings.append(f"Found {blank_count} invalid/blank numeric values in {column} column")
                except Exception as e:
                    errors.append(f"Error converting {column} to numeric: {str(e)}")
                    return None, warnings, errors
            elif column in text_columns:
                # For text columns, check for truly empty values
                blank_mask = cleaned_df[column].isna() | (cleaned_df[column].astype(str).str.strip() == '')
                blank_count = blank_mask.sum()
                if blank_count > 0:
                    warnings.append(f"Dropping {blank_count} rows with blank values in {column} column")
                    cleaned_df = cleaned_df[~blank_mask]
    
        # Check for invalid characters in non-blank rows
        if len(cleaned_df) > 0:
            # Check Positions in Proteins
            invalid_pos = cleaned_df['Positions in Proteins'].apply(
                lambda x: ',' in str(x) or ':' in str(x)
            )
            if invalid_pos.any():
                errors.append(
                    "Found invalid characters (',' or ':') in Positions in Proteins column. "
                    "Please update the file and upload again."
                )
            
            # Check Master Protein Accessions
            invalid_acc = cleaned_df['Master Protein Accessions'].apply(
                lambda x: ',' in str(x) or ':' in str(x)
            )
            if invalid_acc.any():
                errors.append(
                    "Found invalid characters (',' or ':') in Master Protein Accessions column. "
                    "Please update the file and upload again."
                )
    
        return cleaned_df, warnings, errors

    def _process_protein_info(self, df):
        """
        Process protein information from the dataframe and store in proteins_dic.
        Asks user whether to fetch from UniProt or use accession IDs when protein info is missing.
        """
        # Initialize a cache for UniProt information to avoid redundant queries
        self.uniprot_cache = getattr(self, 'uniprot_cache', {})
        
        # Check if we need to fetch any data from UniProt
        has_protein_info = all(col in df.columns for col in ['protein_name', 'protein_species'])
        if has_protein_info:
            # Check if we have valid data for all entries
            all_data_present = (
                df['protein_name'].notna().all() and 
                df['protein_species'].notna().all() and
                (df['protein_name'] != '').all() and
                (df['protein_species'] != '').all()
            )
            if all_data_present:
                # If we have all data, just process it silently
                protein_info = df.groupby('Master Protein Accessions').agg({
                    'protein_name': 'first',
                    'protein_species': 'first'
                }).reset_index()
                
                for _, row in protein_info.iterrows():
                    protein_id = row['Master Protein Accessions']
                    self.proteins_dic[protein_id] = {
                        "name": row['protein_name'],
                        "species": row['protein_species']
                    }
                return len(self.proteins_dic)

        # Store the dataframe for later processing
        self._protein_df_to_process = df
        
        # Create a flag to track if processing is complete
        self._protein_processing_complete = False
        
        # If we need to fetch data, ask the user what they want to do
        # Display in the status area
        with self.status_area:
            self.status_area.clear_output()
            
            # Create buttons for user choice
            fetch_button = widgets.Button(
                description='Query UniProt',
                button_style='info',
                tooltip='Fetch protein names from UniProt database (may take time)',
                layout=widgets.Layout(width='250px')
            )
            
            use_accession_button = widgets.Button(
                description='Use Protein IDs',
                button_style='warning',
                tooltip='Use protein accession IDs as names without querying UniProt',
                layout=widgets.Layout(width='250px')
            )
            
            # Define button click handlers
            fetch_button.on_click(lambda b: self._process_proteins_with_choice(True))
            use_accession_button.on_click(lambda b: self._process_proteins_with_choice(False))
            
            display(HTML("""
                <div style="padding: 15px; margin: 10px 0; border-left: 4px solid #17a2b8; background-color: #f8f9fa;">
                    <h4 style="margin-top: 0;">Protein Information Missing</h4>
                    <p>Some protein names or species information is missing in your data.</p>
                    <p>Would you like to:</p>
                         <ul>
                                <li>Fetch protein names from UniProt database (may take time)</li>
                                <li>Use protein accession IDs as names without querying UniProt</li>
                        </ul>
                </div>
            """))
            display(widgets.HBox([fetch_button, use_accession_button]))
        
        # Return the current count, but processing will continue when a button is clicked
        return len(self.proteins_dic)

    def _process_proteins_with_choice(self, fetch_from_uniprot):
        """
        Process proteins based on user choice.
        This is called when the user clicks one of the choice buttons.
        """
        # Get the dataframe to process
        df = self._protein_df_to_process
        
        # Clear the status area and show processing message
        with self.status_area:
            self.status_area.clear_output()
            if fetch_from_uniprot:
                display(HTML('<div style="color: #17a2b8; padding: 10px; margin: 10px 0;">Fetching protein information from UniProt...</div>'))
            else:
                display(HTML('<div style="color: #ffc107; padding: 10px; margin: 10px 0;">Using protein accession IDs as names...</div>'))
        
        # Process proteins based on user choice
        # Use the status area for progress display
        with self.status_area:
            # Initialize counters
            total_proteins = 0
            uniprot_found = 0
            uniprot_not_found = 0
            multiple_entries = 0
            cached_proteins = 0
            
            # Check if we need to fetch any data from UniProt
            has_protein_info = all(col in df.columns for col in ['protein_name', 'protein_species'])
            
            # Group by protein accession to get unique proteins
            protein_info = df.groupby('Master Protein Accessions').agg({
                'protein_name': 'first' if 'protein_name' in df.columns else lambda x: None,
                'protein_species': 'first' if 'protein_species' in df.columns else lambda x: None
            }).reset_index()

            progress_html = """
                <style>
                    .fetch-status { font-family: monospace; margin: 10px 0; padding: 10px; }
                    .fetch-progress { margin: 5px 0; padding: 5px; }
                    .success { color: #28a745; }
                    .warning { color: #ffc107; }
                    .error { color: #dc3545; }
                    .info { color: #17a2b8; }
                    .summary { margin-top: 10px; padding: 10px;}
                </style>
                <div class="fetch-status">
                    <div id="progress-updates"></div>
                </div>
            """
            
            # First, collect all proteins that need fetching
            proteins_to_fetch = []
            
            for _, row in protein_info.iterrows():
                total_proteins += 1
                protein_id = row['Master Protein Accessions']
                
                # Skip entries with multiple protein IDs
                if ';' in protein_id:
                    multiple_entries += 1
                    self.proteins_dic[protein_id] = {
                        "name": protein_id,
                        "species": "Multiple"
                    }
                    continue
                
                # Use existing data if available and not empty
                if (has_protein_info and 
                    pd.notna(row['protein_name']) and 
                    pd.notna(row['protein_species']) and 
                    row['protein_name'] != '' and 
                    row['protein_species'] != ''):
                    self.proteins_dic[protein_id] = {
                        "name": row['protein_name'],
                        "species": row['protein_species']
                    }
                    continue
                
                # Check if we already have this protein in cache
                if protein_id in self.uniprot_cache:
                    cached_proteins += 1
                    name, species = self.uniprot_cache[protein_id]
                    self.proteins_dic[protein_id] = {
                        "name": name,
                        "species": species
                    }
                    continue
                
                # If we need to fetch and user chose to fetch from UniProt
                if fetch_from_uniprot:
                    proteins_to_fetch.append(protein_id)
                else:
                    # Use accession ID as name
                    self.proteins_dic[protein_id] = {
                        "name": protein_id,
                        "species": "Unknown"
                    }
            
            # Process proteins in batches if fetching from UniProt
            if fetch_from_uniprot and proteins_to_fetch:
                # Update progress display
                display(HTML(progress_html + f"""
                    <div class="fetch-progress info">
                        Preparing to fetch {len(proteins_to_fetch)} proteins from UniProt in batches...
                    </div>
                """))
                
                # Process in batches of 50 (adjust as needed)
                batch_size = 50
                total_batches = (len(proteins_to_fetch) + batch_size - 1) // batch_size
                
                for batch_num in range(total_batches):
                    start_idx = batch_num * batch_size
                    end_idx = min((batch_num + 1) * batch_size, len(proteins_to_fetch))
                    current_batch = proteins_to_fetch[start_idx:end_idx]
                    
                    # Update progress
                    self.status_area.clear_output(wait=True)
                    display(HTML(progress_html + f"""
                        <div class="fetch-progress info">
                            Fetching batch {batch_num + 1}/{total_batches} ({len(current_batch)} proteins)...
                        </div>
                        <div class="summary">
                            <h4>Progress:</h4>
                            <ul>
                                <li>Total proteins: {total_proteins}</li>
                                <li>Proteins from cache: {cached_proteins}</li>
                                <li>UniProt matches found: {uniprot_found}</li>
                                <li>UniProt matches not found: {uniprot_not_found}</li>
                                <li>Multiple entry proteins: {multiple_entries}</li>
                                <li>Remaining to fetch: {len(proteins_to_fetch) - start_idx}</li>
                            </ul>
                        </div>
                    """))
                    
                    # Fetch the batch using individual queries for now
                    # We'll implement batch fetching later
                    batch_results = {}
                    for protein_id in current_batch:
                        try:
                            name, species = self.fetch_uniprot_info(protein_id)
                            if name and species:
                                batch_results[protein_id] = (name, species)
                        except Exception as e:
                            print(f"Error fetching {protein_id}: {str(e)}")
                    
                    # Process the results
                    for protein_id in current_batch:
                        if protein_id in batch_results:
                            name, species = batch_results[protein_id]
                            uniprot_found += 1
                            self.proteins_dic[protein_id] = {
                                "name": name,
                                "species": species
                            }
                            # Add to cache
                            self.uniprot_cache[protein_id] = (name, species)
                        else:
                            uniprot_not_found += 1
                            self.proteins_dic[protein_id] = {
                                "name": protein_id,
                                "species": "Unknown"
                            }
                            # Cache the negative result too
                            self.uniprot_cache[protein_id] = (protein_id, "Unknown")
            
            # Show final summary
            self.status_area.clear_output(wait=True)
            display(HTML(f"""
                <div class="fetch-status">
                    <h4 style="color:green;"><b>Protein Processing Complete!</b></h4>
                    <div class="summary">
                        <h4>Final Summary:</h4>
                        <ul>
                            <li>Total proteins processed: {total_proteins}</li>
                            <li>Proteins from cache: {cached_proteins}</li>
                            <li>Multiple entry proteins: {multiple_entries}</li>
                            {"<li>UniProt matches found: " + str(uniprot_found) + "</li>" if fetch_from_uniprot else ""}
                            {"<li>UniProt matches not found: " + str(uniprot_not_found) + "</li>" if fetch_from_uniprot else ""}
                        </ul>
                    </div>
                </div>
            """))
        
        # Mark processing as complete
        self._protein_processing_complete = True
        
        # Update the merged_df with the protein information
        if hasattr(self, 'merged_df') and self.merged_df is not None:
            # Add protein_name and protein_species columns if they don't exist
            if 'protein_name' not in self.merged_df.columns:
                self.merged_df['protein_name'] = ''
            if 'protein_species' not in self.merged_df.columns:
                self.merged_df['protein_species'] = ''
            
            # Update the columns with the fetched information
            for protein_id, info in self.proteins_dic.items():
                mask = self.merged_df['Master Protein Accessions'] == protein_id
                self.merged_df.loc[mask, 'protein_name'] = info['name']
                self.merged_df.loc[mask, 'protein_species'] = info['species']
        
        # Return the number of proteins processed
        return len(self.proteins_dic)

    def fetch_uniprot_info_batch(self, protein_ids, max_retries=3, timeout=30):
        """
        Fetch protein information for multiple proteins at once using UniProt's batch API.
        Returns a dictionary mapping protein IDs to (name, species) tuples.
        """
        if not protein_ids:
            return {}
        
        results = {}
        
        try:
            # Use the batch REST API endpoint
            batch_url = 'https://rest.uniprot.org/uniprotkb/search'
            
            # Create a query with OR conditions for each protein ID
            query = ' OR '.join([f'accession:{pid}' for pid in protein_ids])
            
            # Parameters for the request
            params = {
                'query': query,
                'format': 'json',
                'fields': 'accession,protein_name,organism_name',
                'size': len(protein_ids)  # Request all results in one response
            }
            
            # Make the request with timeout and retry logic
            retry_count = 0
            while retry_count < max_retries:
                try:
                    response = requests.get(batch_url, params=params, timeout=timeout)
                    
                    # Handle rate limiting
                    if response.status_code == 429:  # Too Many Requests
                        retry_after = int(response.headers.get('Retry-After', 5))
                        print(f"Rate limited. Waiting {retry_after} seconds...")
                        time.sleep(retry_after + 1)  # Add 1 second buffer
                        retry_count += 1
                        continue
                    
                    # Break if successful
                    if response.status_code == 200:
                        break
                    
                    # Handle other errors
                    response.raise_for_status()
                    
                except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
                    retry_count += 1
                    wait_time = 2 ** retry_count  # Exponential backoff
                    print(f"Request failed: {str(e)}. Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
            
            # Process the response
            if response.status_code == 200:
                data = response.json()
                
                for entry in data.get('results', []):
                    accession = entry.get('primaryAccession')
                    
                    # Extract protein name
                    protein_name = None
                    protein_data = entry.get('proteinDescription', {})
                    
                    # Try to find a common/short name first
                    if 'recommendedName' in protein_data:
                        if 'shortNames' in protein_data['recommendedName']:
                            protein_name = protein_data['recommendedName']['shortNames'][0]['value']
                        else:
                            protein_name = protein_data['recommendedName'].get('fullName', {}).get('value')
                    
                    # Try alternative names if no recommended name found
                    if not protein_name and 'alternativeNames' in protein_data and protein_data['alternativeNames']:
                        if 'shortNames' in protein_data['alternativeNames'][0]:
                            protein_name = protein_data['alternativeNames'][0]['shortNames'][0]['value']
                        else:
                            protein_name = protein_data['alternativeNames'][0].get('fullName', {}).get('value')
                    
                    # Extract species name
                    species = None
                    organism_data = entry.get('organism', {})
                    organism_names = organism_data.get('names', [])
                    
                    # Try to find common name first
                    for name in organism_names:
                        if name['type'] == 'common':
                            species = name['value']
                            break
                    
                    # Fallback to scientific name
                    if not species and organism_names:
                        for name in organism_names:
                            if name['type'] == 'scientific':
                                species = name['value']
                                break
                    
                    # Clean up protein name if found
                    if protein_name:
                        protein_name = protein_name.split(' precursor')[0].split(' (')[0]
                    
                    # Store the results
                    if accession and (protein_name or species):
                        results[accession] = (protein_name or accession, species or "Unknown")
            
            return results
        
        except Exception as e:
            print(f"Error in batch fetch: {str(e)}")
            return {}
        
    def _load_merged_data(self, file_data):
        """
        Load and validate merged data file
        Returns tuple of (dataframe, status)
        """
        try:
            content = bytes(file_data.content)
            filename = file_data.name
            extension = filename.split('.')[-1].lower()

            file_stream = io.BytesIO(content)

            # Load data based on file extension
            try:
                if extension == 'csv':
                    df = pd.read_csv(file_stream)
                elif extension in ['txt', 'tsv']:
                    df = pd.read_csv(file_stream, delimiter='\t')
                elif extension == 'xlsx':
                    df = pd.read_excel(file_stream)
                else:
                    display(HTML(f'<b style="color:red;">Error: Unsupported file format</b>'))
                    return None, 'no'
            except Exception as e:
                display(HTML(f'<b style="color:red;">Error reading file: {str(e)}</b>'))
                return None, 'no'

            # Check for protein info columns and notify user
            missing_columns = []
            if 'protein_name' not in df.columns:
                missing_columns.append('protein_name')
                df['protein_name'] = ''
            if 'protein_species' not in df.columns:
                missing_columns.append('protein_species')
                df['protein_species'] = ''
                
            if missing_columns:
                notification = f"""
                <div style="padding: 10px; margin: 10px 0;">
                    <p style="color: #17a2b8; margin: 0;">
                        <b>Notice:</b> The following columns are missing from your data:
                        <ul style="color: #17a2b8; margin: 5px 0;">
                            {''.join(f'<li>{col}</li>' for col in missing_columns)}
                        </ul>
                        </p>
                        <p style="color: #17a2b8; margin: 0;">
                        UniProt will be searched to automatically fill in this information. <br>
                        Alternativly you can upload a standardized file from the data transomation module with the protein information. 
                    </p>
                </div>
                """
                display(HTML(notification))

            # Validate and clean data
            cleaned_df, warnings, errors = self._validate_and_clean_data(df)

            # Warnings about invalid/blank values are commented out
            # if warnings:
            #     warning_html = "<br>".join([
            #         f'<b style="color:orange;">Warning: {w}</b>'
            #         for w in warnings
            #     ])
            #     display(HTML(warning_html))

            # Display errors if any
            if errors:
                error_html = "<br>".join([
                    f'<b style="color:red;">Error: {e}</b>'
                    for e in errors
                ])
                display(HTML(error_html))
                return None, 'no'

            if cleaned_df is not None and len(cleaned_df) > 0:
                # Process protein information
                num_proteins = self._process_protein_info(cleaned_df)
                
                # Add information about remaining rows and processed proteins
                success_message = f"""
                <div style="padding: 10px; margin: 10px 0; border-left: 4px solid #28a745; background-color: #f8f9fa;">
                    <p style="color: #28a745; margin: 0;">
                        <b>Data Import Complete!</b><br>
                        • Data imported successfully with {cleaned_df.shape[0]} rows and {cleaned_df.shape[1]} columns.<br>
                        • Processed data contains {len(cleaned_df)} rows after removing blank values.<br>
                        • Successfully processed information for {num_proteins} unique proteins.
                    </p>
                </div>
                """
                #display(HTML(success_message))
                return cleaned_df, 'yes'
            else:
                display(HTML('<b style="color:red;">Error: No valid data rows remaining after cleaning</b>'))
                return None, 'no'

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error processing file: {str(e)}</b>'))
            return None, 'no'

    def _on_merged_upload_change(self, change):
        """Handle merged data file upload"""
        if change['type'] == 'change' and change['name'] == 'value':
            with self.output_area:
                self.output_area.clear_output()
                if change['new'] and len(change['new']) > 0:
                    file_data = change['new'][0]
                    df, status = self._load_merged_data(file_data)
                    if status == 'yes' and df is not None:
                        self.merged_df = df  # Only set merged_df if validation passed
                        display(HTML(
                            f'<b style="color:green;">Data imported successfully with '
                            f'{df.shape[0]} rows and {df.shape[1]} columns.</b>'
                        ))

    def fetch_uniprot_info(self, protein_id):
        """
        Fetch protein information from UniProt, prioritizing common names.
        Returns tuple of (protein_common_name, species_common_name) or (None, None) if not found.
        """
        try:
            # Try REST API first
            rest_url = f'https://rest.uniprot.org/uniprotkb/{protein_id}.json'
            response = requests.get(rest_url)
            
            if response.status_code == 200:
                data = response.json()
                
                # Get protein common name
                try:
                    # Look for protein names
                    names = data['proteinDescription']
                    protein_name = None
                    
                    # Try to find a common/short name
                    if 'shortNames' in names.get('recommendedName', {}):
                        protein_name = names['recommendedName']['shortNames'][0]['value']
                    elif 'shortNames' in names.get('alternativeNames', [{}])[0]:
                        protein_name = names['alternativeNames'][0]['shortNames'][0]['value']
                    else:
                        # Fallback to full name
                        protein_name = names.get('recommendedName', {}).get('fullName', {}).get('value')
                    
                    # Get species common name
                    organism_data = data['organism']
                    species = None
                    for name in organism_data.get('names', []):
                        if name['type'] == 'common':
                            species = name['value']
                            break
                    if not species:  # Fallback to scientific name
                        species = organism_data.get('scientificName')
                    
                    return protein_name, species
                    
                except KeyError:
                    pass  # Fall through to XML approach
            
            # Fall back to XML API
            xml_url = f'https://www.uniprot.org/uniprot/{protein_id}.xml'
            response = requests.get(xml_url)
            
            if response.status_code != 200:
                return None, None
                
            root = ElementTree.fromstring(response.content)
            ns = {'up': 'http://uniprot.org/uniprot'}
            
            # Get protein common name with fallbacks
            protein_name = None
            # Try short name first
            name_element = (
                root.find('.//up:recommendedName/up:shortName', ns) or
                root.find('.//up:alternativeName/up:shortName', ns) or
                root.find('.//up:recommendedName/up:fullName', ns) or
                root.find('.//up:submittedName/up:fullName', ns)
            )
            protein_name = name_element.text if name_element is not None else None
            
            # Get species common name
            species = None
            # Try common name first
            organism = root.find('.//up:organism/up:name[@type="common"]', ns)
            if organism is not None:
                species = organism.text
            else:
                # Fallback to scientific name
                organism = root.find('.//up:organism/up:name[@type="scientific"]', ns)
                species = organism.text if organism is not None else "Unknown"
            
            if protein_name:
                # Clean up protein name - remove any "precursor" or similar suffixes
                protein_name = protein_name.split(' precursor')[0].split(' (')[0]
                return protein_name, species
            else:
                return None, None
            
        except Exception as e:
            print(f"Error fetching UniProt data for {protein_id}: {str(e)}")
            return None, None




In [3]:
class PlotState:
    def __init__(self):
        self.current_state = {
            'uploadedData': None,
            'topProteins': None,
            'groupSelection': None,
            'xLabel': None,
            'yLabel': None,
            'colorScheme': None,
            'lastGenerated': None,
            'buttonsLocked': True
        }
    
    def update_state(self, **kwargs):
        self.current_state.update(kwargs)
        # Lock buttons when state changes
        self.current_state['buttonsLocked'] = True
    
    def generate_completed(self):
        self.current_state['buttonsLocked'] = False
        self.current_state['lastGenerated'] = {
            key: value for key, value in self.current_state.items() 
            if key not in ['lastGenerated', 'buttonsLocked']
        }
    
    def get_state(self):
        return self.current_state


In [4]:
 
class ProteinPlotter:
    def __init__(self, data_transformer):
        self.data_transformer = data_transformer
        self.plot_output = widgets.Output()
        self.info_output = widgets.Output()
        self.export_output = widgets.Output()
        self.proteins_df = None
        self.sum_df = None
        
        # Initialize state manager first
        self.state_manager = PlotState()
    
        # Create widgets for protein plotting
        self.num_proteins_widget = widgets.IntSlider(
            value=10,
            min=1,
            max=50,
            step=1,
            description='Top Number of Proteins:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )
        
        # Create widgets for protein plotting
        self.num_proteins_widget = widgets.IntSlider(
            value=10,
            min=1,
            max=50,
            step=1,
            description='Top Number of Proteins:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )


        self.download_plot_button = widgets.Button(
            description='Download Interactive Plot',
            button_style='info',
            icon='file',
            layout=widgets.Layout(width='200px')
        )

        # Create multi-select widget for groups
        self.group_select = widgets.SelectMultiple(
            options=[],
            description='Groups:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px', height='100px')
        )
        
        self.plot_button = widgets.Button(
            description='Generate/Update Data',
            button_style='success',
            icon='refresh',
            layout=widgets.Layout(width='200px')
        )
        
        self.export_button = widgets.Button(
            description='Export Data',
            button_style='info',
            icon='download',
            layout=widgets.Layout(width='200px')
        )

        # Add label customization widgets
        self.xlabel_widget = widgets.Text(
            description='X Label:',
            placeholder='Enter x-axis label',
            layout=widgets.Layout(width='300px')
        )
        
        self.ylabel_widget = widgets.Text(
            description='Y Label:',
            placeholder='Enter y-axis label',
            layout=widgets.Layout(width='300px')
        )
        self.legend_widget = widgets.Text(
            description='Legend Title',
            placeholder='Enter a custom legend title',
            layout=widgets.Layout(width='300px')
        )
        self.title_widget = widgets.Text(
            description='Plot Title',
            placeholder='Enter a custom plot title',
            layout=widgets.Layout(width='300px')
        )


        # Update color scheme dropdown with categorized options
        color_schemes = [
            '--- DEFAULT (HSV)---',
            'HSV',  # Default option
            '--- QUALITATIVE (BEST FOR CATEGORIES) ---',
            'Plotly', 'D3', 'G10', 'T10', 'Alphabet', 
            'Set1', 'Set2', 'Set3', 'Pastel1', 'Pastel2', 'Paired',
            '--- SEQUENTIAL ---',
            'Viridis', 'Cividis', 'Inferno', 'Magma', 'Plasma',
            'Hot', 'Jet', 'Blues', 'Greens', 'Reds', 'Purples', 'Oranges',
            '--- DIVERGING ---',
            'Spectral', 'RdBu', 'RdYlBu', 'RdYlGn', 'PiYG', 'PRGn', 'BrBG', 'RdGy',
            '--- CYCLICAL ---',
            'IceFire', 'Edge', 'Twilight'
        ]
        
        self.color_scheme = widgets.Dropdown(
            options=color_schemes,
            value='HSV',  # Default value
            description='Color Scheme:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px')
        )

        # Add an inversion toggle radio button
        self.invert_plot = widgets.RadioButtons(
            description='Plot Orientation:',
            options=['By Sample', 'By Protein'],
            value='By Sample',  # Default selection
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='300px', height='auto'),
            disabled=False,
            indent=True  # Keeps options aligned with description instead of appearing below
        )
        

        # Create layout
        self.widget_box = widgets.VBox([
            widgets.HTML("<h4>Plot Controls:</h4>"),
            self.num_proteins_widget,
            self.group_select,
            widgets.HTML("<h4>Appearance Settings:</h4>"),
            self.invert_plot,
            self.xlabel_widget,
            self.ylabel_widget,
            self.legend_widget,
            self.title_widget,
            self.color_scheme,
            widgets.HTML("<h4>Actions:</h4>"),
            self.plot_button,
            self.export_button,
            self.download_plot_button,
            self.info_output,
            self.plot_output,
            self.export_output
        ])
            
            
        # Set initial button states
        self.export_button.disabled = True
        self.download_plot_button.disabled = True
        
        def _on_input_change(change):
            self.state_manager.update_state()
            self.export_button.disabled = True
            self.download_plot_button.disabled = True
        
        # Add observers for input changes
        self.num_proteins_widget.observe(_on_input_change, names='value')
        self.group_select.observe(_on_input_change, names='value')
        self.xlabel_widget.observe(_on_input_change, names='value')
        self.ylabel_widget.observe(_on_input_change, names='value')
        self.legend_widget.observe(_on_input_change, names='value')
        self.title_widget.observe(_on_input_change, names='value')
        self.color_scheme.observe(_on_input_change, names='value')

        # Add button click handlers
        self.plot_button.on_click(self._on_plot_button_click)
        self.export_button.on_click(self._on_export_button_click)
        self.download_plot_button.on_click(self._on_download_plot_click)
        
        # Add variable to store current figure
        self.current_fig = None
        # Add observer for data changes
        self.data_transformer.merged_uploader.observe(self._update_group_options, names='value')



        # In your setup_widgets method, add this line:
        self.color_scheme.observe(self._on_color_scheme_change, names='value')

    # Then add this method:
    def _on_color_scheme_change(self, change):
        """Update plot when color scheme changes"""
        if self.current_fig is not None and hasattr(self, 'plot_button'):
            # Trigger plot update by simulating a button click
            self._on_plot_button_click(None)

    def _get_avg_columns(self):
        """Get all columns that start with 'Avg_' from the merged dataframe"""
        if self.data_transformer.merged_df is not None:
            return [col for col in self.data_transformer.merged_df.columns if col.startswith('Avg_')]
        return []

    def _update_group_options(self, change):
        """Update group selection options when data changes"""
        if self.data_transformer.merged_df is not None:
            avg_columns = self._get_avg_columns()
            # Remove 'Avg_' prefix for display
            group_options = [col.replace('Avg_', '') for col in avg_columns]
            self.group_select.options = group_options
            # Select all groups by default
            self.group_select.value = group_options

    def process_data(self, selected_groups=None):
        """Process data with optional group selection"""
        if self.data_transformer.merged_df is None or not self.data_transformer.proteins_dic:
            return False

        df = self.data_transformer.merged_df.copy()
        
        # Get Absorbance columns based on selected groups
        if selected_groups:
            Absorbance_cols = [f'Avg_{var}' for var in selected_groups]
        else:
            Absorbance_cols = self._get_avg_columns()
            
        df['Total_Absorbance'] = df[Absorbance_cols].sum(axis=1).astype(int)
        
        # Filter out zero Absorbance entries
        result_df = df[['unique ID', 'Total_Absorbance']]
        result_df = result_df[result_df['Total_Absorbance'] == 0]
        all_zero_list = list(result_df['unique ID'])
        peptides_df = df[~df['unique ID'].isin(all_zero_list)]

        # Process protein positions and create proteins DataFrame
        additional_columns = ['Master Protein Accessions', 'unique ID']
        selected_columns = additional_columns + Absorbance_cols
        
        peptides_df.loc[:, 'Master Protein Accessions'] = peptides_df['Master Protein Accessions']
        
        temp_df = peptides_df.copy()
        temp_df.loc[:, 'Protein_ID'] = temp_df['Master Protein Accessions']
        
        # Create proteins DataFrame with selected columns
        self.proteins_df = temp_df.groupby('Protein_ID').agg(
            {**{col: 'first' for col in ['Master Protein Accessions']},
             **{col: 'sum' for col in Absorbance_cols}}
        ).reset_index()
        
        # Calculate relative Absorbance for selected groups
        for col in Absorbance_cols:
            col_sum = self.proteins_df[col].sum()
            if col_sum > 0:  # Avoid division by zero
                self.proteins_df[f'Rel_{col}'] = (self.proteins_df[col] / col_sum) * 100
            else:
                self.proteins_df[f'Rel_{col}'] = 0
            
        # Create sum DataFrame for selected groups
        self.sum_df = pd.DataFrame({
            'Sample': Absorbance_cols,
            'Total_Sum': [self.proteins_df[col].sum() for col in Absorbance_cols]
        })
        
        # Add protein descriptions
        name_list = []
        for _, row in self.proteins_df.iterrows():
            if ',' in row['Protein_ID']:
                strrow = row['Protein_ID'].split(',')
                named_combo = self._fetch_protein_names('; '.join(strrow))
            else:
                named_combo = self._fetch_protein_names(row['Protein_ID'])
            name_list.append(named_combo)
        
        # Drop the 'Protein_ID' column
        self.proteins_df = self.proteins_df.drop(columns=['Protein_ID'])    
        
        self.proteins_df['Description'] = name_list
        self.proteins_df['Description'] = self.proteins_df['Description'].astype(str).str.replace(r"['\['\]]", "", regex=True)
        
        return True

    def _fetch_protein_names(self, accession_str):
        """
        Fetch protein names from the proteins dictionary.
        Returns a list of protein names, using the full protein name.
        """
        names = []
        for acc in accession_str.split('; '):
            if acc in self.data_transformer.proteins_dic:
                # Use the full protein name instead of splitting it
                name = self.data_transformer.proteins_dic[acc]['name']
                names.append(name)
            else:
                names.append(acc)
        return names

    def _get_color_sequence(self, n_colors):
        """Get color sequence based on selected scheme."""
        if n_colors <= 0:
            return []
        
        try:
            scheme = self.color_scheme.value
            
            # Skip header options that start with '---'
            if scheme.startswith('---'):
                scheme = 'HSV'  # Default to HSV if a header is selected
            
            # Handle special cases
            if scheme.lower() in ['rainbow', 'hsv']:
                return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
            
            # Try qualitative color scales first (best for categorical data)
            color_sequence = getattr(px.colors.qualitative, scheme, None)
            if color_sequence is None:
                # Try sequential color scales
                color_sequence = getattr(px.colors.sequential, scheme, None)
            if color_sequence is None:
                # Try diverging color scales
                color_sequence = getattr(px.colors.diverging, scheme, None)
            if color_sequence is None:
                # Try cyclical color scales
                color_sequence = getattr(px.colors.cyclical, scheme, None)
            
            if color_sequence:
                if n_colors >= len(color_sequence):
                    # If we need more colors than available, interpolate
                    indices = np.linspace(0, len(color_sequence)-1, n_colors)
                    return [color_sequence[int(i)] for i in indices]
                else:
                    # If we need fewer colors, take a subset
                    indices = np.linspace(0, len(color_sequence)-1, n_colors, dtype=int)
                    return [color_sequence[i] for i in indices]
            
            # Default to HSV if no matching scheme found
            return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
            
        except Exception as e:
            print(f"Error generating colors: {e}")
            # Fallback to HSV
            return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]   
  
    def plot_stacked_bar_scaled(self, pro_list, title, selected_groups):
        if self.proteins_df is None or self.sum_df is None:
            return None
                
        scaled_df = self.proteins_df.copy()
        
        # Filter for selected groups
        sample_orders = [f'Rel_Avg_{var}' for var in selected_groups]
        
        # Create sample mapping for selected groups
        sample_mapping = {
            f'Rel_Avg_{var}': f'Avg_{var}' for var in selected_groups
        }
        
        # Convert relative absorbance columns to percentages
        for col in sample_orders:
            scaled_df[col] = self.proteins_df[col] * 100
                
        # Calculate total sums for each group first
        total_sums = {}
        for col in sample_orders:
            sample_key = sample_mapping[col]
            total_sum = self.sum_df.loc[self.sum_df['Sample'] == sample_key, 'Total_Sum'].values[0]
            total_sums[col.replace('Rel_Avg_', '')] = total_sum  # Store with clean group name
        
        # Scale absolute Absorbance
        for col in sample_orders:
            sample_key = sample_mapping[col]
            total_sum = self.sum_df.loc[self.sum_df['Sample'] == sample_key, 'Total_Sum'].values[0]
            if total_sum > 0:
                scaled_df[col] = self.proteins_df[col] * total_sum / self.proteins_df[col].sum()
        
        # Create a new DataFrame for the Minor Proteins
        minor_proteins_df = pd.DataFrame()
        
        # For each sample, calculate the sum of proteins not in pro_list
        for sample in sample_orders:
            abs_col = sample_mapping[sample]
            # Sum absorbance for proteins not in pro_list
            minor_proteins_absorbance = scaled_df[~scaled_df['Description'].isin(pro_list)][sample].sum()
            minor_proteins_rel_absorbance = self.proteins_df[~self.proteins_df['Description'].isin(pro_list)][sample].sum()
            
            # Add to the minor proteins DataFrame
            if len(minor_proteins_df) == 0:
                minor_proteins_df = pd.DataFrame({
                    'Description': ['Minor Proteins'],
                    'Master Protein Accessions': ['Minor Proteins'],
                    sample: [minor_proteins_absorbance],
                    abs_col: [minor_proteins_absorbance]
                })
            else:
                minor_proteins_df[sample] = [minor_proteins_absorbance]
                minor_proteins_df[abs_col] = [minor_proteins_absorbance]
        
        # Filter scaled_df to only include proteins in pro_list
        scaled_df = scaled_df[scaled_df['Description'].isin(pro_list)]
        
        # Sort proteins_df based on pro_list
        description_order = {desc: i for i, desc in enumerate(pro_list)}
        scaled_df['Order'] = scaled_df['Description'].map(description_order)
        scaled_df = scaled_df.sort_values(by='Order').reset_index(drop=True)
        
        # Append minor proteins to the end
        scaled_df = pd.concat([scaled_df, minor_proteins_df], ignore_index=True)
        
        # Create figure object before adding traces
        fig = go.Figure()

            # Calculate total absorbance for each sample
        total_absorbance = {}
        for sample in sample_orders:
            total_absorbance[sample] = scaled_df[sample].sum()


        if self.invert_plot.value == 'By Sample':
            # Get colors based on selected color scheme
            colors = self._get_color_sequence(len(pro_list))
            # Add gray color for Minor Proteins
            colors.append('#808080')  # Gray color for Minor Proteins
        

            # Filter out zero values for plotting
            for idx, row in scaled_df.iterrows():
                protein_description = row['Description']
                if protein_description in pro_list or protein_description == 'Minor Proteins':
                    # Use gray for Minor Proteins, otherwise use the color from the sequence
                    if protein_description == 'Minor Proteins':
                        color = '#808080'  # Gray color
                    else:
                        color = colors[pro_list.index(protein_description)]
                    
                    # Only plot non-zero values
                    x_values = []
                    y_values = []

                    for i, sample in enumerate(sample_orders):
                        value = row[sample]
                        if value > 0:  # Only include non-zero values
                            x_values.append(sample.replace('Rel_Avg_', ''))
                            y_values.append(value)
                            
                            abs_col = sample_mapping[sample]
                            if protein_description == 'Minor Proteins':
                                rel_value_hov = row[sample] / total_absorbance[sample] * 100 if total_absorbance[sample] > 0 else 0
                            else:
                                rel_value_hov = self.proteins_df.loc[self.proteins_df['Description'] == protein_description, sample].values[0]
                            abs_value = row[abs_col]
                            

                    
                    if y_values:  # Only add trace if there are non-zero values
                        fig.add_trace(go.Bar(
                            name=protein_description,
                            x=[label.replace('Rel_Avg_', '') for label in sample_orders],  # Keep all x values for alignment
                            y=row[sample_orders],  # Keep all y values, including zeros
                            marker_color=color,
                            hovertext=[
                                f"Protein: {row['Master Protein Accessions']}<br>" +
                                f"Description: {row['Description']}<br>" +
                                #f"Sample: {sample.replace('Rel_Avg_', '')}<br>" +
                                f"Relative Absorbance: {self.proteins_df.loc[self.proteins_df['Description'] == protein_description, s].values[0] if protein_description != 'Minor Proteins' else row[s] / total_absorbance[s] * 100 if total_absorbance[s] > 0 else 0:.2f}%<br>" +
                                f"Absolute Absorbance: {row[sample_mapping[s]]:.2e}<br>"
                                for s in sample_orders
                            ],
                            hoverinfo='text'
                        ))
        else:
            # Inverted plotting logic (proteins on x-axis)
            # Get colors based on selected color scheme
            colors = self._get_color_sequence(len(selected_groups))
            
            # For each sample, create a trace
            for i, sample in enumerate(sample_orders):
                sample_display = sample.replace('Rel_Avg_', '')
                abs_col = sample_mapping[sample]
                
                # Prepare data for this sample
                x_values = []  # Protein names
                y_values = []  # Values for this sample

                for _, row in scaled_df.iterrows():
                    protein_description = row['Description']
                    if protein_description in pro_list or protein_description == 'Minor Proteins':
                        value = row[sample]
                        if value > 0:  # Only include non-zero values
                            x_values.append(protein_description)
                            y_values.append(value)
                            
                            # Create hover text
                            if protein_description == 'Minor Proteins':
                                rel_value_hov = row[sample] / total_absorbance[sample] * 100 if total_absorbance[sample] > 0 else 0
                            else:
                                rel_value_hov = self.proteins_df.loc[self.proteins_df['Description'] == protein_description, sample].values[0]
                            

                if y_values:  # Only add trace if there are non-zero values
                    fig.add_trace(go.Bar(
                        name=sample_display,  # Use sample name for legend
                        x=pro_list + (['Minor Proteins'] if 'Minor Proteins' in scaled_df['Description'].values else []),
                        y=scaled_df[sample],  # Use all values to maintain alignment
                        marker_color=colors[i],
                        hovertext=[
                            #f"Protein: {row['Master Protein Accessions']}<br>" +
                            #f"Description: {row['Description']}<br>" +
                            f"Sample: {sample_display}<br>" +
                            f"Relative Absorbance: {self.proteins_df.loc[self.proteins_df['Description'] == row['Description'], sample].values[0] if row['Description'] != 'Minor Proteins' else row[sample] / total_absorbance[sample] * 100 if total_absorbance[sample] > 0 else 0:.2f}%<br>" +
                            f"Absolute Absorbance: {row[abs_col]:.2e}<br>"
                            for _, row in scaled_df.iterrows()
                        ],
                        hoverinfo='text'
                    ))

        
        # Get custom labels
        x_label = self.xlabel_widget.value or ('Proteins' if self.invert_plot.value == 'By Protein' else 'Samples')
        y_label = self.ylabel_widget.value or 'Summed Absorbance'
        plot_title = self.title_widget.value or title or 'Protein Distribution Analysis'
        legend_title = self.legend_widget.value or ('Samples' if self.invert_plot.value == 'By Protein' else 'Protein Origins')

        fig.update_layout(
            barmode='stack',
            title={
                'text': plot_title,
                'y': 0.95,
                'x': 0.5,
                'xanchor': 'center',
                'yanchor': 'top',
                'font': {"size": 18, 'color': 'black'}
            },
        xaxis_title=x_label,
        yaxis_title=y_label,

        #yaxis_type='log',  # Set y-axis to log scale
        yaxis=dict(
            showline=True,
            gridcolor='lightgray',
            showgrid=True,
            showticklabels=True,  # Hide tick labels on y-axis
            linewidth=1,
            linecolor='black',
            mirror=False,
            zeroline=False  # Don't show zero line
        ),
        xaxis=dict(
            showline=True,
            linewidth=1,
            linecolor='black',
            mirror=False,
            tickangle=-90 if self.invert_plot.value == 'By Sample' else 45  # Adjust tick angle based on orientation
        ),
        legend_title=legend_title,
        legend={
            'yanchor': "top",
            'y': 0.95,
            'xanchor': "left",
            'x': 1.05,
            'traceorder': 'normal',
            'font': {"size": 14, 'color': 'black'},
            'bgcolor': 'rgba(255, 255, 255, 0.9)'
        },
        showlegend=True,
        template='plotly_white',
        height=820,
        width=1200,
        margin=dict(
            t=100,
            l=100,
            r=100,
            b=100
        ),
        hoverlabel=dict(
            bgcolor="white",
            font_size=14,
            font_family="Arial"
        ))
        
        fig.update_xaxes(
            tickangle=45,
            title_font={"size": 18},
            tickfont={"size": 16},
            tickfont_color="black",  # Black tick labels
            title_font_color="black",  # Black axis title                
        )
        
        fig.update_yaxes(
            title_font={"size": 18},
            tickfont={"size": 16},
            tickfont_color="black",  # Black tick labels
            title_font_color="black",  # Black axis title
            gridcolor="lightgray",  # Light gray grid lines
            showgrid=True,  # Show grid lines
            zeroline=False,  # Hide zero line
            exponentformat='E',
            showexponent='all',
            tickformat=".1e"
        )
        # Always add scatter trace for totals, but calculate differently based on orientation
        if self.invert_plot.value == 'By Sample':
            # Original sample-wise totals calculation
            fig.add_trace(go.Scatter(
                x=[label.replace('Rel_Avg_', '') for label in sample_orders],
                y=[total_sums[group.replace('Rel_Avg_', '')] for group in sample_orders],
                mode='text',
                text=[f"{total_sums[group.replace('Rel_Avg_', '')]:.2e}" for group in sample_orders],
                textposition='top center',
                textfont=dict(size=12, color='black'),
                showlegend=True,
                name='Show Summed Absorbance',
                hoverinfo='none',  # Remove hover info for summed absorbance
                texttemplate='%{text}'
            ))
        else:
            # Calculate protein-wise totals
            protein_sums = {}
            for protein in pro_list + (['Minor Proteins'] if 'Minor Proteins' in scaled_df['Description'].values else []):
                # Sum all sample values for this protein
                protein_values = scaled_df[scaled_df['Description'] == protein][sample_orders].values[0]
                protein_total = sum(protein_values)
                protein_sums[protein] = protein_total

            # Add the protein totals trace
            fig.add_trace(go.Scatter(
                x=list(protein_sums.keys()),
                y=list(protein_sums.values()),
                mode='text',
                text=[f"{value:.2e}" for value in protein_sums.values()],
                textposition='top center',
                textfont=dict(size=12, color='black'),
                showlegend=True,
                name='Show Summed Absorbance',
                hoverinfo='none',  # Remove hover info for summed absorbance
                texttemplate='%{text}'
            ))       
            
        # Mark generation as complete
        self.state_manager.generate_completed()
        self.export_button.disabled = False
        self.download_plot_button.disabled = False
        
        return fig
    
    def _on_plot_button_click(self, b):
        if self.current_fig is not None:
            self.state_manager.generate_completed()
            self.export_button.disabled = False
            self.download_plot_button.disabled = False
        with self.plot_output:
            self.plot_output.clear_output(wait=True)
            
            if not self.group_select.value:
                print("Please select at least one group to plot.")
                return
            
            selected_groups = list(self.group_select.value)
            if self.process_data(selected_groups):
                # Calculate average Absorbance for sorting using only selected groups
                selected_avg_columns = [f'Avg_{var}' for var in selected_groups]
                #self.proteins_df['avg_absorbance_all'] = self.proteins_df[selected_avg_columns].mean(axis=1)
                                
                # Calculate sum of all selected columns
                total_sum =  self.proteins_df[selected_avg_columns].sum().sum()
                
                # Calculate row sums
                row_sums =  self.proteins_df[selected_avg_columns].sum(axis=1)
                
                # Calculate relative percentage contribution
                self.proteins_df['avg_absorbance_all'] = (row_sums / total_sum * 100).round(2)
                
                # Sort and get top N proteins
                self.proteins_df = self.proteins_df.sort_values('avg_absorbance_all', ascending=False)
                
                # Sort and get top N proteins
                pro_list = list(self.proteins_df ['Description'].head(self.num_proteins_widget.value))
                
                # Create and store plot
                self.current_fig = self.plot_stacked_bar_scaled(
                    pro_list=pro_list,
                    title='Protein Distribution Analysis',  # Add a default title here
                    selected_groups=selected_groups
                )
                if self.current_fig is not None:
                    self.current_fig.show()
            else:
                print("Please upload all required files first.")
                print("Error creating plot. Please check your data.")
    
    def generate_download_link(self, content, filename, filetype='text/csv'):
        """Generate a download link for any content"""
        if isinstance(content, pd.DataFrame):
            if filetype == 'text/csv':
                content = content.to_csv(index=False)
            else:
                content = content.to_csv(index=True)
        if isinstance(content, str):
            content = content.encode()
        b64 = base64.b64encode(content).decode()
        return f"""
            <a id="download_link" href="data:{filetype};base64,{b64}" 
               download="{filename}"
               style="display: none;">
                Download {filename}
            </a>
            <script>
                document.getElementById('download_link').click();
            </script>
            """
    
    def _on_download_plot_click(self, b):
        """Handle plot download button click with automatic download"""
        if self.current_fig is not None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            plot_filename = f'protein_plot_{timestamp}.html'
            
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(self.generate_download_link(
                    self.current_fig.to_html(),
                    plot_filename,
                    'text/html'
                )))
        else:
            print("Please generate a plot first.")
    
    def _on_export_button_click(self, b):
        """Handle data export with automatic download"""
        if self.proteins_df is not None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            data_filename = f'protein_absorbance_analysis_{timestamp}.csv'
            
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(self.generate_download_link(
                    self.proteins_df,
                    data_filename,
                    'text/csv'
                )))
        else:
            print("Please generate the analysis first.")

    def display(self):
        """Display the protein analysis interface"""
        display(self.widget_box)



In [None]:

# Initialize the interface
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

# Create protein plotter
protein_plotter = ProteinPlotter(data_transformer)
protein_plotter.display()

VBox(children=(HTML(value='<h4>Upload Data File:</h4>'), HBox(children=(FileUpload(value=(), accept='.csv,.txt…

Output(layout=Layout(margin='0 0 0 20px', width='300px'))

VBox(children=(HTML(value='<h4>Plot Controls:</h4>'), IntSlider(value=10, description='Top Number of Proteins:…

In [5]:

# Initialize the interface
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

# Create protein plotter
protein_plotter = ProteinPlotter(data_transformer)
protein_plotter.display()

VBox(children=(HTML(value='<h4>Upload Data File:</h4>'), HBox(children=(FileUpload(value=(), accept='.csv,.txt…

Output(layout=Layout(margin='0 0 0 20px', width='300px'))

VBox(children=(HTML(value='<h4>Plot Controls:</h4>'), IntSlider(value=10, description='Top Number of Proteins:…