In [1]:
import pandas as pd
import numpy as np
from datetime import datetime
import json, io, base64, re, os
import plotly.graph_objects as go
from IPython.display import display, HTML, clear_output
import plotly.express as px
import ipywidgets as widgets

# Initialize settings
import _settings as settings

# Global variables from settings
# Species lookup table read by DataTransformation._find_species. Each entry is
# indexed as entry[0] = species label to return and entry[1:] = search terms that
# are matched case-insensitively as substrings of a FASTA header line.
# NOTE(review): the concrete contents live in _settings.SPEC_TRANSLATE_LIST — confirm there.
spec_translate_list = settings.SPEC_TRANSLATE_LIST


In [25]:
class DataTransformation:
    """Upload UI and parsers for the merged peptide table, group definitions and FASTA files.

    Attributes filled in by the upload handlers:
        merged_df    -- validated peptide DataFrame, or None until a valid upload.
        group_data   -- dict parsed from the group-definition JSON, or None until a valid upload.
        proteins_dic -- accession -> {"name", "sequence", "species"} built from FASTA uploads.
    """

    def __init__(self):
        self.merged_df = None
        self.group_data = None
        self.proteins_dic = {}
        self.output_area = None
        self.merged_uploader = None
        self.group_uploader = None
        self.fasta_uploader = None
        # Widgets created in setup_data_loading_ui(); declared here so the attributes always exist.
        self.reset_button = None
        self.status_area = None

    def create_download_link(self, file_path, label):
        """Return an HTML widget linking to ``file_path`` as an inline base64 download.

        If the file does not exist, a red "not found" message widget is returned instead.
        """
        if os.path.exists(file_path):
            # Read file content and encode it as base64
            with open(file_path, 'rb') as f:
                content = f.read()
            b64_content = base64.b64encode(content).decode('utf-8')

            # Generate the download link HTML
            return widgets.HTML(f"""
                <a download="{os.path.basename(file_path)}" 
                   href="data:application/octet-stream;base64,{b64_content}" 
                   style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                    {label}
                </a>
            """)
        else:
            # Show an error message if the file does not exist
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{file_path}" not found!
                </span>
            """)

    def setup_data_loading_ui(self):
        """Create, wire up and display the three upload widgets, the reset button and the message area."""
        # Create file upload widgets
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.group_uploader = widgets.FileUpload(
            accept='.json',
            multiple=False,
            description='Upload Group Definition',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.fasta_uploader = widgets.FileUpload(
            accept='.fasta',
            multiple=True,
            description='Upload FASTA Files',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        # Reset button
        self.reset_button = widgets.Button(
            description='Reset',
            button_style='warning'
        )

        self.output_area = widgets.Output()

        # Create individual upload boxes with example links
        merged_box = widgets.HBox([
            self.merged_uploader,
            self.create_download_link("example_merged_dataframe.csv", "Example")
        ], layout=widgets.Layout(align_items='center'))

        group_box = widgets.HBox([
            self.group_uploader,
            self.create_download_link("example_group_definition.json", "Example")
        ], layout=widgets.Layout(align_items='center'))

        fasta_box = widgets.HBox([
            self.fasta_uploader,
            self.create_download_link("example_fasta.fasta", "Example")
        ], layout=widgets.Layout(align_items='center'))

        # Create left column with upload widgets. The reset button must be part of
        # the layout, otherwise its click handler can never be triggered.
        upload_widgets = widgets.VBox([
            widgets.HTML("<h3><u>Upload Data Files:</u></h3>"),
            merged_box,
            widgets.HTML("<h3><u>Upload Group Definition:</u></h3>"),
            group_box,
            widgets.HTML("<h3><u>Upload Protein FASTA Files:</u></h3>"),
            fasta_box,
            self.reset_button,
            self.output_area
        ], layout=widgets.Layout(
            width='400px',
            margin='0 20px 0 0'
        ))

        # Create container for status display
        self.status_area = widgets.Output(
            layout=widgets.Layout(
                width='400px',
                margin='0 0 0 20px'
            )
        )

        # Create grid layout
        grid = widgets.GridBox(
            [upload_widgets, self.status_area],
            layout=widgets.Layout(
                grid_template_columns='auto auto',
                grid_gap='20px',
                width='900px'
            )
        )

        # Register observers
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')
        self.group_uploader.observe(self._on_group_upload_change, names='value')
        self.fasta_uploader.observe(self._on_fasta_upload_change, names='value')
        self.reset_button.on_click(self._reset_ui)

        # Display the grid
        display(grid)

    def _reset_ui(self, b):
        """Forget all uploaded data and clear the upload widgets.

        The stored data is cleared *before* the widget values are reset: resetting
        a FileUpload value fires every observer registered on it (possibly ones
        owned by other objects), and those observers must see the cleared state.
        """
        self.merged_df = None
        self.group_data = None
        self.proteins_dic = {}
        for uploader in (self.merged_uploader, self.group_uploader, self.fasta_uploader):
            uploader.value = ()
        with self.output_area:
            self.output_area.clear_output()
            display(HTML('<b style="color:blue;">All uploads cleared.</b>'))

    @staticmethod
    def _validate_group_data(group_data):
        """Raise ValueError unless ``group_data`` has the expected group-definition structure.

        Expected: a dict of group_key -> dict containing at least
        'grouping_variable' and a list-valued 'abundance_columns'.
        """
        if not isinstance(group_data, dict):
            raise ValueError("Group definition must be a dictionary")
        required_keys = {'grouping_variable', 'abundance_columns'}
        for group_key, group_info in group_data.items():
            if not isinstance(group_info, dict):
                raise ValueError(f"Group {group_key} must be a dictionary")
            missing_keys = required_keys - group_info.keys()
            if missing_keys:
                raise ValueError(f"Group {group_key} missing required keys: {missing_keys}")
            if not isinstance(group_info['abundance_columns'], list):
                raise ValueError(f"Group {group_key} abundance_columns must be a list")

    def _on_group_upload_change(self, change):
        """Handle group definition JSON file upload.

        ``self.group_data`` is only replaced once the file has parsed and passed
        validation, so a malformed upload never leaves invalid data behind.
        """
        if change['type'] != 'change' or change['name'] != 'value':
            return
        with self.output_area:
            self.output_area.clear_output()
            if not change['new']:
                return
            file_data = change['new'][0]
            try:
                # 'utf-8-sig' also accepts files saved with a byte-order mark.
                content = bytes(file_data.content).decode('utf-8-sig')
                group_data = json.loads(content)
                self._validate_group_data(group_data)
                self.group_data = group_data

                # str() so non-string grouping variables can still be listed.
                group_vars = [str(group_info['grouping_variable']) for group_info in group_data.values()]

                display(HTML(
                    f'<b style="color:green;">Group definition file imported successfully with {len(group_data)} groups.</b><br>' + 
                    f'<b>Groups loaded: {", ".join(group_vars)}</b>'
                ))

            except json.JSONDecodeError as e:
                display(HTML(f'<b style="color:red;">Invalid JSON format: {str(e)}</b>'))
            except ValueError as e:
                display(HTML(f'<b style="color:red;">Invalid group definition format: {str(e)}</b>'))
            except Exception as e:
                display(HTML(f'<b style="color:red;">Error loading group definition file: {str(e)}</b>'))

    def _on_fasta_upload_change(self, change):
        """Parse every uploaded FASTA file and merge its proteins into ``self.proteins_dic``."""
        if change['type'] != 'change' or change['name'] != 'value':
            return
        with self.output_area:
            self.output_area.clear_output()
            for file_data in change['new'] or ():
                try:
                    file_name = getattr(file_data, 'name', None)
                    # Case-insensitive so e.g. 'PROTEINS.FASTA' is accepted too.
                    if file_name and file_name.lower().endswith('.fasta'):
                        new_proteins = self._parse_uploaded_fasta(file_data)
                        self.proteins_dic.update(new_proteins)
                        display(HTML(f'<b style="color:green;">Successfully imported FASTA file: {file_name} ({len(new_proteins)} proteins)</b>'))
                    else:
                        display(HTML(f'<b style="color:red;">Invalid file format. Please upload FASTA files only.</b>'))
                except Exception as e:
                    display(HTML(f'<b style="color:red;">Error processing FASTA file: {str(e)}</b>'))

    def _validate_and_clean_data(self, df):
        """
        Validate and clean the uploaded data, dropping rows with blank values in key columns.

        A value is blank when it is missing (NaN/None) or a whitespace-only string.
        Returns tuple of (cleaned_df, warnings, errors); cleaned_df is None when
        a required column is absent.
        """
        warnings = []
        errors = []

        # Check required columns exist
        required_columns = ['Master Protein Accessions', 'Positions in Proteins']
        missing = set(required_columns) - set(df.columns)
        if missing:
            errors.append(f"Missing required columns: {', '.join(missing)}")
            return None, warnings, errors

        cleaned_df = df.copy()

        # Handle blank values by dropping rows and issuing warnings
        for column in required_columns:
            values = cleaned_df[column]
            blank = values.isna() | values.astype(str).str.strip().eq('')
            blank_count = int(blank.sum())
            if blank_count > 0:
                warnings.append(f"Dropping {blank_count} rows with blank values in {column} column")
                cleaned_df = cleaned_df[~blank]

        # ',' and ':' are rejected in the key columns; Positions entries are
        # expected to be ';'-separated (see ProteinPlotter.process_data).
        if len(cleaned_df) > 0:
            for column in ('Positions in Proteins', 'Master Protein Accessions'):
                if cleaned_df[column].astype(str).str.contains(r'[,:]', regex=True).any():
                    errors.append(
                        f"Found invalid characters (',' or ':') in {column} column. "
                        "Please update the file and upload again."
                    )

        return cleaned_df, warnings, errors

    def _load_merged_data(self, file_data):
        """
        Load and validate merged data file.

        Supports .csv (comma), .txt/.tsv (tab) and .xlsx. Messages are displayed
        in the current output context.
        Returns tuple of (dataframe, status) where status is 'yes' or 'no'.
        """
        try:
            content = bytes(file_data.content)
            filename = file_data.name
            extension = filename.split('.')[-1].lower()

            file_stream = io.BytesIO(content)

            # Load data based on file extension
            try:
                if extension == 'csv':
                    df = pd.read_csv(file_stream)
                elif extension in ['txt', 'tsv']:
                    df = pd.read_csv(file_stream, delimiter='\t')
                elif extension == 'xlsx':
                    df = pd.read_excel(file_stream)
                else:
                    display(HTML(f'<b style="color:red;">Error: Unsupported file format</b>'))
                    return None, 'no'
            except Exception as e:
                display(HTML(f'<b style="color:red;">Error reading file: {str(e)}</b>'))
                return None, 'no'

            # Validate and clean data
            cleaned_df, warnings, errors = self._validate_and_clean_data(df)

            # Display warnings about dropped rows
            if warnings:
                warning_html = "<br>".join([
                    f'<b style="color:orange;">Warning: {w}</b>' 
                    for w in warnings
                ])
                display(HTML(warning_html))

            # Display errors if any
            if errors:
                error_html = "<br>".join([
                    f'<b style="color:red;">Error: {e}</b>' 
                    for e in errors
                ])
                display(HTML(error_html))
                return None, 'no'

            if cleaned_df is not None and len(cleaned_df) > 0:
                # Add information about remaining rows
                display(HTML(
                    f'<b style="color:green;">Processed data contains {len(cleaned_df)} rows '
                    f'after removing blank values.</b>'
                ))
                return cleaned_df, 'yes'
            else:
                display(HTML('<b style="color:red;">Error: No valid data rows remaining after cleaning</b>'))
                return None, 'no'

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error processing file: {str(e)}</b>'))
            return None, 'no'

    def _on_merged_upload_change(self, change):
        """Handle merged data file upload; ``merged_df`` is only replaced when validation passes."""
        if change['type'] == 'change' and change['name'] == 'value':
            with self.output_area:
                self.output_area.clear_output()
                if change['new'] and len(change['new']) > 0:
                    file_data = change['new'][0]
                    df, status = self._load_merged_data(file_data)
                    if status == 'yes' and df is not None:
                        self.merged_df = df  # Only set merged_df if validation passed
                        display(HTML(
                            f'<b style="color:green;">Merged data imported successfully with '
                            f'{df.shape[0]} rows and {df.shape[1]} columns.</b>'
                        ))

    def _find_species(self, header):
        """Return the label of the first spec_translate_list entry with a term found in ``header``, else "unknown"."""
        header_lower = header.lower()
        for spec_group in spec_translate_list:
            for term in spec_group[1:]:
                if term.lower() in header_lower:
                    return spec_group[0]
        return "unknown"

    @staticmethod
    def _split_fasta_header(header):
        """Return (protein_id, protein_name) for a FASTA header without its leading '>'.

        UniProt headers ``db|ACCESSION|ENTRY_NAME Protein name OS=...`` give
        (ACCESSION, "ENTRY_NAME Protein name"). Any other header is keyed by its
        first whitespace-delimited token and keeps the whole header text (minus
        any " OS=..." tail) as the name.
        """
        header_parts = header.split('|')
        if len(header_parts) > 2:
            return header_parts[1], re.split(r' OS=', header_parts[2])[0]
        tokens = header.split()
        protein_id = tokens[0] if tokens else ""
        return protein_id, re.split(r' OS=', header.strip())[0]

    def _parse_uploaded_fasta(self, file_data):
        """Parse uploaded FASTA content into {protein_id: {"name", "sequence", "species"}}.

        Every header starts a new record, so an unusual header can no longer make
        its sequence overwrite the previous protein's record. Lines before the
        first header are ignored; a later duplicate ID replaces an earlier one.
        """
        fasta_dict = {}
        fasta_text = bytes(file_data.content).decode('utf-8')

        protein_id = ""
        record = None
        sequence_parts = []

        for raw_line in fasta_text.splitlines():
            line = raw_line.strip()
            if line.startswith('>'):
                if protein_id:
                    record["sequence"] = "".join(sequence_parts)
                    fasta_dict[protein_id] = record
                protein_id, protein_name = self._split_fasta_header(line[1:])
                record = {"name": protein_name, "sequence": "", "species": self._find_species(line)}
                sequence_parts = []
            else:
                sequence_parts.append(line)

        if protein_id:
            record["sequence"] = "".join(sequence_parts)
            fasta_dict[protein_id] = record

        return fasta_dict


In [57]:
class ProteinPlotter:
    def __init__(self, data_transformer):
        self.data_transformer = data_transformer
        self.plot_output = widgets.Output()
        self.info_output = widgets.Output()
        self.export_output = widgets.Output()
        self.proteins_df = None
        self.sum_df = None
        
        # Create widgets for protein plotting
        self.num_proteins_widget = widgets.IntSlider(
            value=10,
            min=1,
            max=50,
            step=1,
            description='Top Number of Proteins:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )
        
        # Create multi-select widget for groups
        self.group_select = widgets.SelectMultiple(
            options=[],
            description='Select Groups:',
            style={'description_width': 'initial'},

            layout=widgets.Layout(width='400px', height='100px')
        )
        
        self.plot_button = widgets.Button(
            description='Generate Plot',
            layout=widgets.Layout(width='200px')
        )
        
        self.export_button = widgets.Button(
            description='Export Data',
            layout=widgets.Layout(width='200px')
        )
        # Add label customization widgets
        self.xlabel_widget = widgets.Text(
            description='X Label:',
            placeholder='Enter x-axis label',
            layout=widgets.Layout(width='400px')
        )
        
        self.ylabel_widget = widgets.Text(
            description='Y Label:',
            placeholder='Enter y-axis label',
            layout=widgets.Layout(width='400px')
        )

        self.color_scheme = widgets.Dropdown(
            options=[
                # Sequential
                ('Viridis', 'Viridis'),
                ('Cividis', 'Cividis'),
                ('Inferno', 'Inferno'),
                ('Magma', 'Magma'),
                ('Plasma', 'Plasma'),
                ('Warm', 'Warm'),
                ('Cool', 'Cool'),
                ('Hot', 'Hot'),
                ('Jet', 'Jet'),
                # Sequential (Blues)
                ('Blues', 'Blues'),
                ('Bluered', 'Bluered'),
                ('Blugrn', 'Blugrn'),
                # Sequential (Greens)
                ('Greens', 'Greens'),
                ('Gnbu', 'GnBu'),
                # Sequential (Purples)
                ('Purples', 'Purples'),
                ('Pubu', 'PuBu'),
                ('Purd', 'PuRd'),
                ('Purp', 'Purp'),
                # Sequential (Oranges/Reds)
                ('Oranges', 'Oranges'),
                ('Reds', 'Reds'),
                ('Orrd', 'OrRd'),
                # Diverging
                ('Spectral', 'Spectral'),
                ('RdBu', 'RdBu'),
                ('RdYlBu', 'RdYlBu'),
                ('RdYlGn', 'RdYlGn'),
                ('PiYG', 'PiYG'),
                ('PRGn', 'PRGn'),
                ('BrBG', 'BrBG'),
                ('RdGy', 'RdGy'),
                # Cyclical
                ('Rainbow', 'Rainbow'),
                ('IceFire', 'IceFire'),
                ('Edge', 'Edge'),
                ('HSV', 'HSV'),
                ('Twilight', 'Twilight'),
                ('Mrybm', 'Mrybm'),
                ('Mygbm', 'Mygbm'),
            ],
            value='HSV',
            description='Color Scheme:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        # Create layout
        self.widget_box = widgets.VBox([
            widgets.HTML("<h3>Protein Absorbance Analysis</h3>"),
            self.num_proteins_widget,
            self.group_select,
            self.xlabel_widget,
            self.ylabel_widget,
            self.color_scheme,
            widgets.HBox([self.plot_button, self.export_button]),
            self.info_output,
            self.plot_output,
            self.export_output
        ])

        # Add button click handlers
        self.plot_button.on_click(self._on_plot_button_click)
        self.export_button.on_click(self._on_export_button_click)

        # Add observer for data changes
        self.data_transformer.group_uploader.observe(self._update_group_options, names='value')

    def process_data(self, selected_groups=None):
        """Process data with optional group selection"""
        if not all([self.data_transformer.merged_df is not None,
                   self.data_transformer.group_data != {},
                   self.data_transformer.proteins_dic is not None]):
            return False

        df = self.data_transformer.merged_df.copy()
        
        # Get Absorbance columns based on selected groups
        if selected_groups:
            Absorbance_cols = [f'Avg_{var}' for var in selected_groups]
        else:
            Absorbance_cols = [col for col in df.columns if 'Avg_' in col]
            
        df['Total_Absorbance'] = df[Absorbance_cols].sum(axis=1).astype(int)
        
        # Filter out zero Absorbance entries
        result_df = df[['unique ID', 'Total_Absorbance']]
        result_df = result_df[result_df['Total_Absorbance'] == 0]
        all_zero_list = list(result_df['unique ID'])
        peptides_df = df[~df['unique ID'].isin(all_zero_list)]

        # Process protein positions and create proteins DataFrame
        additional_columns = ['Master Protein Accessions', 'Positions in Proteins', 'unique ID']
        selected_columns = additional_columns + Absorbance_cols
        
        peptides_df.loc[:, 'Positions in Proteins'] = peptides_df['Positions in Proteins'].apply(
            lambda x: re.sub(r'\[\d+-\d+\]', '', x).replace(';', ',').strip(',').strip()
        )
        
        temp_df = peptides_df.copy()
        temp_df.loc[:, 'Protein_ID'] = temp_df['Positions in Proteins']
        
        # Create proteins DataFrame with selected columns
        self.proteins_df = temp_df.groupby('Protein_ID').agg(
            {**{col: 'first' for col in ['Master Protein Accessions']},
             **{col: 'sum' for col in Absorbance_cols}}
        ).reset_index()
        
        # Calculate relative Absorbance for selected groups
        for col in Absorbance_cols:
            col_sum = self.proteins_df[col].sum()
            if col_sum > 0:  # Avoid division by zero
                self.proteins_df[f'Rel_{col}'] = (self.proteins_df[col] / col_sum) * 100
            else:
                self.proteins_df[f'Rel_{col}'] = 0
            
        # Create sum DataFrame for selected groups
        self.sum_df = pd.DataFrame({
            'Sample': Absorbance_cols,
            'Total_Sum': [self.proteins_df[col].sum() for col in Absorbance_cols]
        })
        
        # Add protein descriptions
        name_list = []
        for _, row in self.proteins_df.iterrows():
            if ',' in row['Protein_ID']:
                strrow = row['Protein_ID'].split(',')
                named_combo = self._fetch_protein_names('; '.join(strrow))
            else:
                named_combo = self._fetch_protein_names(row['Protein_ID'])
            name_list.append(named_combo)
        
        # Drop the 'Protein_ID' column
        self.proteins_df = self.proteins_df.drop(columns=['Protein_ID'])    
        
        self.proteins_df['Description'] = name_list
        self.proteins_df['Description'] = self.proteins_df['Description'].astype(str).str.replace(r"['\['\]]", "", regex=True)
        
        return True

    def _fetch_protein_names(self, accession_str):
        names = []
        for acc in accession_str.split('; '):
            if acc in self.data_transformer.proteins_dic:
                name = self.data_transformer.proteins_dic[acc]['name'].split()[1]
                names.append(name)
            else:
                names.append(acc)
        return names

    def _update_group_options(self, change):
        """Update group selection options when data changes"""
        if self.data_transformer.group_data is not None:
            grouping_vars = [group_info['grouping_variable'] 
                           for group_info in self.data_transformer.group_data.values()]
            self.group_select.options = grouping_vars
            # Select all groups by default
            self.group_select.value = grouping_vars
         
    def plot_stacked_bar_scaled(self, pro_list, title, selected_groups):
            if self.proteins_df is None or self.sum_df is None:
                return None
                
            scaled_df = self.proteins_df.copy()
            
            # Filter for selected groups
            sample_orders = [f'Rel_Avg_{var}' for var in selected_groups]
            
            # Create sample mapping for selected groups
            sample_mapping = {
                f'Rel_Avg_{var}': f'Avg_{var}' for var in selected_groups
            }
            
            # Convert relative absorbance columns to percentages
            for col in sample_orders:
                scaled_df[col] = self.proteins_df[col] * 100
                
            # Scale absolute Absorbance
            for col in sample_orders:
                sample_key = sample_mapping[col]
                total_sum = self.sum_df.loc[self.sum_df['Sample'] == sample_key, 'Total_Sum'].values[0]
                if total_sum > 0:
                    scaled_df[col] = self.proteins_df[col] * total_sum / self.proteins_df[col].sum()
            
            # Sort proteins_df based on pro_list
            description_order = {desc: i for i, desc in enumerate(pro_list)}
            scaled_df['Order'] = scaled_df['Description'].map(description_order)
            scaled_df = scaled_df.sort_values(by='Order').reset_index(drop=True)
            
            fig = go.Figure()
                
            def get_color_sequence(scheme, n_colors):
                """Get color sequence based on selected scheme."""
                try:
                    # Get the color sequence from plotly's built-in sequences
                    if scheme.lower() in ['rainbow', 'hsv']:
                        return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
                    
                    # Get the corresponding plotly color sequence
                    color_sequence = getattr(px.colors.sequential, scheme, None)
                    if color_sequence is None:
                        color_sequence = getattr(px.colors.diverging, scheme, None)
                    if color_sequence is None:
                        color_sequence = getattr(px.colors.cyclical, scheme, None)
                    
                    if color_sequence:
                        # Ensure we get the right number of colors
                        if n_colors >= len(color_sequence):
                            # If we need more colors than available, interpolate
                            indices = np.linspace(0, len(color_sequence)-1, n_colors)
                            return [color_sequence[int(i)] for i in indices]
                        else:
                            # If we need fewer colors, take evenly spaced samples
                            indices = np.linspace(0, len(color_sequence)-1, n_colors, dtype=int)
                            return [color_sequence[i] for i in indices]
                    
                    # Fallback to rainbow if scheme not found
                    return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
                
                except Exception as e:
                    print(f"Error generating colors: {e}")
                    # Fallback to basic rainbow
                    return [f'hsl({h},70%,60%)' for h in np.linspace(0, 330, n_colors)]
    
            # Get colors based on selected scheme
            colors = get_color_sequence(self.color_scheme.value, len(pro_list))
            
            for idx, row in scaled_df.iterrows():
                protein_description = row['Description']
                if protein_description in pro_list:
                    color = colors[pro_list.index(protein_description)]
                    
                    hover_text = []
                    for sample in sample_orders:
                        abs_col = sample_mapping[sample]
                        rel_value = self.proteins_df.loc[idx, sample]
                        rel_value_hov = self.proteins_df.loc[self.proteins_df['Description'] == protein_description, sample].values[0]
                        abs_value = row[abs_col]
                        hover_text.append(
                            f"Protein: {row['Master Protein Accessions']}<br>" +
                            f"Description: {row['Description']}<br>" +
                            f"Sample: {sample.replace('Rel_Avg_', '')}<br>" +
                            f"Relative Absorbance: {rel_value_hov:.2f}%<br>" +
                            f"Absolute Absorbance: {abs_value:.2e}"
                        )
                    
                    fig.add_trace(go.Bar(
                        name=protein_description,
                        x=[label.replace('Rel_Avg_', '') for label in sample_orders],
                        y=row[sample_orders],
                        marker_color=color,
                        hovertext=hover_text,
                        hoverinfo='text'
                    ))
    
            # Get custom labels
            x_label = self.xlabel_widget.value or 'Sample Type'
            y_label = self.ylabel_widget.value or 'Scaled Absolute Absorbance'
            
            fig.update_layout(
                barmode='stack',
                title={
                    'text': title,
                    'y': 0.95,
                    'x': 0.5,
                    'xanchor': 'center',
                    'yanchor': 'top'
                },
                xaxis_title=x_label,
                yaxis_title=y_label,
                legend_title="Protein",
                legend={'yanchor': "top", 'y': 1, 'xanchor': "left", 'x': 1.05},
                showlegend=True,
                template='plotly_white',
                height=600,
                width=1000,
                margin=dict(t=100, l=100, r=200),
                hoverlabel=dict(
                    bgcolor="white",
                    font_size=12,
                    font_family="Arial"
                )
            )
            
            fig.update_xaxes(
                tickangle=45,
                title_font={"size": 14},
                tickfont={"size": 12}
            )
            
            fig.update_yaxes(
                title_font={"size": 14},
                tickfont={"size": 12},
                exponentformat='E',
                showexponent='all'
            )
            
            return fig


    def _on_plot_button_click(self, b):
        with self.plot_output:
            self.plot_output.clear_output(wait=True)
            
            if not self.group_select.value:
                print("Please select at least one group to plot.")
                return
            
            selected_groups = list(self.group_select.value)
            if self.process_data(selected_groups):
                # Calculate average Absorbance for sorting using only selected groups
                selected_rel_columns = [f'Rel_Avg_{var}' for var in selected_groups]
                self.proteins_df['avg_absorbance_all'] = self.proteins_df[selected_rel_columns].mean(axis=1)
                
                # Sort and get top N proteins
                sorted_proteins = self.proteins_df.sort_values('avg_absorbance_all', ascending=False)
                pro_list = list(sorted_proteins['Description'].head(self.num_proteins_widget.value))
                
                # Create and display plot
                fig = self.plot_stacked_bar_scaled(
                    pro_list=pro_list,
                    title=f'',#Top {self.num_proteins_widget.value} Proteins - Relative Absorbance',
                    selected_groups=selected_groups
                )
                if fig is not None:
                    fig.show()
            else:
                print("Please upload all required files first.")
                print("Error creating plot. Please check your data.")

    def _on_export_button_click(self, b):
        with self.export_output:
            self.export_output.clear_output(wait=True)
            
            if self.proteins_df is not None:
                # Export to CSV
                csv_data = self.proteins_df.to_csv(index=False).encode('utf-8')
                b64 = base64.b64encode(csv_data).decode()
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                filename = f'protein_Absorbance_analysis_{timestamp}.csv'
                
                # Create download link
                html_str = f'''
                    <a download="{filename}" 
                       href="data:text/csv;base64,{b64}" 
                       class="download-link" 
                       style="background-color: #4CAF50;
                              border: none;
                              color: white;
                              padding: 10px 20px;
                              text-align: center;
                              text-decoration: none;
                              display: inline-block;
                              font-size: 14px;
                              margin: 4px 2px;
                              cursor: pointer;
                              border-radius: 4px;">
                        Download {filename}
                    </a>
                '''
                display(HTML(html_str))
            else:
                print("Please generate the analysis first.")

    def display(self):
        """Render this plotter's widget container (``self.widget_box``) in the notebook."""
        # `display` below resolves to the module-level IPython function, not this method.
        container = self.widget_box
        display(container)


# Initialize the interface: build the data-transformation object and render its
# file-upload UI (setup_data_loading_ui initializes and displays that UI).
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

# Create protein plotter bound to the same data_transformer (presumably it reads
# the uploaded data from it — confirm in ProteinPlotter.__init__), then render
# its widget box. Keep this after the data-loading setup above.
protein_plotter = ProteinPlotter(data_transformer)
protein_plotter.display()

GridBox(children=(VBox(children=(HTML(value='<h3><u>Upload Data Files:</u></h3>'), HBox(children=(FileUpload(v…

VBox(children=(HTML(value='<h3>Protein Absorbance Analysis</h3>'), IntSlider(value=10, description='Top Number…

In [41]:
# Plain alias, not a copy: `df` is the very DataFrame held by the plotter.
# Column assignments made on that frame later (the Plot callback writes
# `avg_absorbance_all` in place) are visible through `df`. If process_data
# rebinds `proteins_df` to a new frame, `df` keeps pointing at the old one.
df = protein_plotter.proteins_df

In [42]:
# All rows whose Description is Alpha-lactalbumin (displayed as the cell result)
df.loc[df['Description'].eq('Alpha-lactalbumin')]

Unnamed: 0,Master Protein Accessions,Avg_Acid_Low_DH,Avg_Sweet_Low_DH1,Avg_Sweet_Low_DH2,Avg_Acid_Mid_DH,Avg_Sweet_Mid_DH,Avg_Sweet_High_DH1,Avg_Sweet_High_DH2,Rel_Avg_Acid_Low_DH,Rel_Avg_Sweet_Low_DH1,Rel_Avg_Sweet_Low_DH2,Rel_Avg_Acid_Mid_DH,Rel_Avg_Sweet_Mid_DH,Rel_Avg_Sweet_High_DH1,Rel_Avg_Sweet_High_DH2,Description,avg_absorbance_all
13,P00711,19165450000.0,27579900000.0,15248290000.0,13056000000.0,7224882000.0,4097130000.0,2626658000.0,8.710411,6.81001,6.557491,5.061462,6.077251,13.023648,12.397124,Alpha-lactalbumin,8.376771
14,P00711,3083917000.0,5912715000.0,1043788000.0,637183900.0,68757200.0,122615100.0,101804200.0,1.401594,1.459964,0.448878,0.247019,0.057836,0.38976,0.480489,Alpha-lactalbumin,0.640791


In [43]:

# Inspect column dtypes and one sample row for a single protein of interest.
PROTEIN_OF_INTEREST = 'Alpha-lactalbumin'

alpha_df = df[df['Description'] == PROTEIN_OF_INTEREST]

header = f"Data types for {PROTEIN_OF_INTEREST} rows:"
print(header)
print("=" * len(header))
for column, dtype in alpha_df.dtypes.items():
    print(f"{column}: {dtype}")

# Print some example values to verify. Guard first: .iloc[0] raises
# IndexError when no row matches the description.
if alpha_df.empty:
    print(f"\nNo rows found with Description == {PROTEIN_OF_INTEREST!r}.")
else:
    header = "Sample values from first row:"
    print("\n" + header)
    print("=" * len(header))
    first_row = alpha_df.iloc[0]
    for column, value in first_row.items():
        print(f"{column}: {value} (type: {type(value).__name__})")


Data types for Alpha-lactalbumin rows:
Master Protein Accessions: object
Avg_Acid_Low_DH: float64
Avg_Sweet_Low_DH1: float64
Avg_Sweet_Low_DH2: float64
Avg_Acid_Mid_DH: float64
Avg_Sweet_Mid_DH: float64
Avg_Sweet_High_DH1: float64
Avg_Sweet_High_DH2: float64
Rel_Avg_Acid_Low_DH: float64
Rel_Avg_Sweet_Low_DH1: float64
Rel_Avg_Sweet_Low_DH2: float64
Rel_Avg_Acid_Mid_DH: float64
Rel_Avg_Sweet_Mid_DH: float64
Rel_Avg_Sweet_High_DH1: float64
Rel_Avg_Sweet_High_DH2: float64
Description: object
avg_absorbance_all: float64

Sample values from first row:
Master Protein Accessions: P00711 (type: str)
Avg_Acid_Low_DH: 19165450188.351444 (type: float64)
Avg_Sweet_Low_DH1: 27579897816.42547 (type: float64)
Avg_Sweet_Low_DH2: 15248293760.257385 (type: float64)
Avg_Acid_Mid_DH: 13056001540.235628 (type: float64)
Avg_Sweet_Mid_DH: 7224882374.315495 (type: float64)
Avg_Sweet_High_DH1: 4097130192.282995 (type: float64)
Avg_Sweet_High_DH2: 2626657854.991533 (type: float64)
Rel_Avg_Acid_Low_DH: 8.71041098