In [1]:
import pandas as pd
import numpy as np
from datetime import datetime
import json, io, base64, re, os, tempfile, zipfile
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
from itertools import combinations


In [2]:
class DataTransformation:
    """Upload widgets and parsers for the merged peptide table and the group definitions.

    After ``setup_data_loading_ui()`` is displayed and the user uploads files,
    ``merged_df`` holds the parsed table and ``group_data`` holds the validated
    group definitions as
    ``{group_id: {'grouping_variable': ..., 'abundance_columns': ...}}``.
    """

    def __init__(self):
        self.merged_df = None        # DataFrame once a merged data file loads
        self.group_data = None       # validated group-definition dict
        self.info_area = None        # Output widget for upload messages
        self.merged_uploader = None  # FileUpload for the merged table
        self.group_uploader = None   # FileUpload for the group-definition JSON
        self.reset_button = None     # NOTE(review): never assigned a widget in this class
        self.status_area = None      # Output widget created by setup_data_loading_ui()

    def create_download_link(self, file_path, label):
        """Return an HTML widget that downloads ``file_path`` via a base64 data URI.

        The whole file is read into memory and embedded in the link. If the file
        does not exist, a red "not found" message widget is returned instead.

        Args:
            file_path: path of the file to offer for download.
            label: text shown for the link.
        """
        if os.path.exists(file_path):
            # Read file content and encode it as base64
            with open(file_path, 'rb') as f:
                content = f.read()
            b64_content = base64.b64encode(content).decode('utf-8')

            # Generate the download link HTML
            return widgets.HTML(f"""
                <a download="{os.path.basename(file_path)}" 
                   href="data:application/octet-stream;base64,{b64_content}" 
                   style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                    {label}
                </a>
            """)
        else:
            # Show an error message if the file does not exist
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{file_path}" not found!
                </span>
            """)    

    def setup_data_loading_ui(self):
        """Create, display and wire up the two upload widgets.

        Creates the merged-data and group-definition uploaders, each with an
        "Example" download link, plus the ``info_area`` and ``status_area``
        output widgets, then registers the upload observers.
        """
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.group_uploader = widgets.FileUpload(
            accept='.json',
            multiple=False,
            description='Upload Group Definition',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.info_area = widgets.Output()

        # Individual upload boxes, each with a link to an example file
        merged_box = widgets.HBox([
            self.merged_uploader,
            self.create_download_link("example_merged_dataframe.csv", "Example")
        ], layout=widgets.Layout(align_items='center'))

        group_box = widgets.HBox([
            self.group_uploader,
            self.create_download_link("example_group_definition.json", "Example")
        ], layout=widgets.Layout(align_items='center'))

        upload_widgets = widgets.VBox([
            widgets.HTML("<h4>Upload Data Files:</h4>"),
            merged_box,
            widgets.HTML("<h4>Upload Group Definition:</h4>"),
            group_box,
            self.info_area
        ], layout=widgets.Layout(
            width='300px',
            margin='0 20px 0 0'
        ))

        self.status_area = widgets.Output()

        display(upload_widgets,
                self.status_area)

        self.merged_uploader.observe(self._on_merged_upload_change, names='value')
        self.group_uploader.observe(self._on_group_upload_change, names='value')

    def _on_merged_upload_change(self, change):
        """Observer: parse the uploaded merged data file into ``merged_df``.

        Messages are written to ``info_area``. On a parse failure
        ``merged_df`` becomes ``None`` (see ``_load_merged_data``).
        """
        if change['type'] == 'change' and change['name'] == 'value':
            with self.info_area:
                self.info_area.clear_output()
                uploaded_files = change.get('new', ())
                if uploaded_files:
                    # multiple=False, so only the first upload matters
                    file_data = uploaded_files[0]
                    self.merged_df = self._load_merged_data(file_data)
                    if self.merged_df is not None:
                        display(HTML(
                            f'<b style="color:green;">Merged data imported: '
                            f'{self.merged_df.shape[0]} rows, {self.merged_df.shape[1]} columns</b>'
                        ))

    def _on_group_upload_change(self, change):
        """Observer: parse and validate the uploaded group-definition JSON into ``group_data``.

        On failure an error is shown in ``info_area`` and ``group_data`` keeps
        its previous value.
        """
        if change['type'] == 'change' and change['name'] == 'value':
            with self.info_area:
                self.info_area.clear_output()
                uploaded_files = change.get('new', ())
                if uploaded_files:
                    # multiple=False, so only the first upload matters
                    file_data = uploaded_files[0]
                    try:
                        # 'utf-8-sig' also accepts files saved with a UTF-8 BOM,
                        # which json.loads would otherwise reject.
                        content = bytes(file_data.content).decode('utf-8-sig')
                        group_data = json.loads(content)
                        self.group_data = self._process_group_data(group_data)
                        display(HTML(
                            f'<b style="color:green;">Group definition imported successfully with {len(self.group_data)} groups.</b><br>' 
                        ))
                    except Exception as e:
                        display(HTML(f'<b style="color:red;">Error loading group definition: {str(e)}</b>'))

    def _load_merged_data(self, file_data):
        """Parse an uploaded table and check that it has the required columns.

        Args:
            file_data: uploaded file object exposing ``.content`` (bytes-like)
                and ``.name``.

        Returns:
            The parsed DataFrame. Returns ``None`` after displaying an error if
            the format is unsupported, parsing fails, or required columns are
            missing.
        """
        try:
            content = bytes(file_data.content)
            filename = file_data.name
            extension = filename.split('.')[-1].lower()

            file_stream = io.BytesIO(content)

            if extension == 'csv':
                df = pd.read_csv(file_stream)
            elif extension in ['txt', 'tsv']:
                df = pd.read_csv(file_stream, delimiter='\t')
            elif extension == 'xlsx':
                df = pd.read_excel(file_stream)
            else:
                raise ValueError("Unsupported file format")

            required_columns = ['Master Protein Accessions', 'Positions in Proteins']
            missing = set(required_columns) - set(df.columns)
            if missing:
                # Sorted so the message is deterministic (set order is not).
                raise ValueError(f"Missing required columns: {', '.join(sorted(missing))}")

            return df

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error loading data: {str(e)}</b>'))
            return None

    def _process_group_data(self, json_data):
        """Validate parsed group-definition JSON and keep only the fields used downstream.

        Args:
            json_data: parsed JSON; must be an object mapping group IDs to objects
                that contain ``grouping_variable`` and ``abundance_columns``.

        Returns:
            ``{group_id: {'grouping_variable': ..., 'abundance_columns': ...}}``.

        Raises:
            ValueError: prefixed with "Error processing group data:" for any
                structural problem.
        """
        try:
            if not isinstance(json_data, dict):
                raise ValueError(
                    "expected a JSON object mapping group IDs to definitions, "
                    f"got {type(json_data).__name__}"
                )

            processed_data = {}
            for group_id, group_info in json_data.items():
                # Without this check a string entry would pass the `in` tests
                # below as a substring match and fail later with a TypeError.
                if not isinstance(group_info, dict):
                    raise ValueError(f"Group {group_id} definition must be a JSON object")
                for field in ('grouping_variable', 'abundance_columns'):
                    if field not in group_info:
                        raise ValueError(f"Group {group_id} missing {field}")

                processed_data[group_id] = {
                    'grouping_variable': group_info['grouping_variable'],
                    'abundance_columns': group_info['abundance_columns']
                }

            return processed_data
        except Exception as e:
            raise ValueError(f"Error processing group data: {str(e)}") from e


In [3]:
class CorrelationPlotter:
    def __init__(self, data_transformer):
        """Create every plot-control widget, wire up the callbacks and assemble the layout.

        Args:
            data_transformer: loader whose ``merged_df`` / ``group_data`` feed the
                plots. Its ``merged_uploader`` and ``group_uploader`` must already
                be widgets, because value observers are attached to them here.
        """
        self.data_transformer = data_transformer

        # Separate output areas for the plot, status messages and export links.
        self.plot_output = widgets.Output()
        self.info_area = widgets.Output()
        self.export_output = widgets.Output()

        # Function-filter choices for whatever data is loaded right now.
        self.function_options = self._get_function_options()

        def fresh_layout(width='300px'):
            # Each widget gets its own Layout instance; none are shared.
            return widgets.Layout(width=width)

        def action_button(text, style, icon_name, start_disabled=False):
            return widgets.Button(
                description=text,
                button_style=style,
                icon=icon_name,
                layout=fresh_layout(),
                disabled=start_disabled,
            )

        self.plot_button = action_button('Generate/Update Data', 'success', 'refresh')

        # Group selection and analysis options.
        self.group1_dropdown = widgets.Dropdown(
            description='Group 1:', options=[], layout=fresh_layout()
        )
        self.group2_dropdown = widgets.Dropdown(
            description='Group 2:', options=[], layout=fresh_layout()
        )
        self.correlation_type = widgets.Dropdown(
            options=['Pearson', 'Spearman'],
            description='Correlation:',
            value='Pearson',
            layout=fresh_layout(),
        )
        self.log_transform = widgets.Checkbox(
            value=True, description='Log10 transform data', layout=fresh_layout()
        )

        # Two function filters, each paired with a marker colour.
        self.function1_widget = widgets.Dropdown(
            options=self.function_options,
            description='Function 1:',
            value='All Peptides',
            layout=fresh_layout(),
        )
        self.function1_color = widgets.ColorPicker(
            concise=False, description='Color:', value='#0072C6', layout=fresh_layout('200px')
        )
        self.function2_widget = widgets.Dropdown(
            options=self.function_options,
            description='Function 2:',
            value='All Peptides',
            layout=fresh_layout(),
        )
        self.function2_color = widgets.ColorPicker(
            concise=False, description='Color:', value='#0072C6', layout=fresh_layout('200px')
        )

        # Optional axis-label overrides.
        self.xlabel_widget = widgets.Text(
            description='X Label:', placeholder='Enter x-axis label', layout=fresh_layout()
        )
        self.ylabel_widget = widgets.Text(
            description='Y Label:', placeholder='Enter y-axis label', layout=fresh_layout()
        )

        # Export actions start disabled; _on_plot_button_click enables them
        # after a plot has been shown.
        self.batch_export_button = action_button(
            'Export All Plot Combinations', 'info', 'download', True
        )
        self.download_html_button = action_button(
            'Download Interactive Plots', 'info', 'file', True
        )
        self.export_group_correlation_button = action_button(
            'Export Sample-to-Sample Correlations', 'info', 'download', True
        )
        self.export_replicate_correlation_button = action_button(
            'Export Technical Replicate Correlations', 'info', 'download', True
        )

        click_handlers = (
            (self.plot_button, self._on_plot_button_click),
            (self.batch_export_button, self._handle_batch_export),
            (self.download_html_button, self._on_download_html_click),
            (self.export_group_correlation_button, self._on_export_group_correlation_click),
            (self.export_replicate_correlation_button, self._on_export_replicate_correlation_click),
        )
        for button, handler in click_handlers:
            button.on_click(handler)

        # Each function dropdown sits next to its colour picker.
        self.function1_box = widgets.HBox([self.function1_widget, self.function1_color])
        self.function2_box = widgets.HBox([self.function2_widget, self.function2_color])

        self.widget_box = widgets.VBox([
            widgets.HTML("<h4>Plot Controls:</h4>"),
            self.group1_dropdown,
            self.group2_dropdown,
            self.correlation_type,
            self.log_transform,
            self.function1_box,
            self.function2_box,
            widgets.HTML("<h4>Appearance Settings:</h4>"),
            self.xlabel_widget,
            self.ylabel_widget,
            widgets.HTML("<h4>Actions:</h4>"),
            self.plot_button,
            self.batch_export_button,
            self.download_html_button,
            self.export_group_correlation_button,
            self.export_replicate_correlation_button,
            self.info_area,
            self.export_output,
            self.plot_output,
        ])

        # Only uploads trigger automatic refreshes of the dropdown choices;
        # plots are redrawn on explicit button clicks.
        for uploader in (self.data_transformer.merged_uploader,
                         self.data_transformer.group_uploader):
            uploader.observe(self._on_data_change, names='value')


    def _on_plot_button_click(self, b):
        """Handle generate plot button click"""
        with self.plot_output:
            self.plot_output.clear_output(wait=True)

            if self.group1_dropdown.value == self.group2_dropdown.value:
                display(HTML('<b style="color:red;">Please select different groups for comparison.</b>'))
                return

            fig = self.create_correlation_plot(
                self.group1_dropdown.value,
                self.group2_dropdown.value
            )

            if fig is not None:
                fig.show(config={
                    'displayModeBar': True,
                    'scrollZoom': True,
                    'modeBarButtonsToRemove': ['select2d', 'lasso2d']
                })
                self.batch_export_button.disabled = False
                self.export_group_correlation_button.disabled = False
                self.export_replicate_correlation_button.disabled = False
                self.download_html_button.disabled = False                
            else:
                display(HTML('<b style="color:red;">No data available for the selected combination.</b>'))

    # Add this new method to the class:
    def _download_plot(self, b):
        """Download the current plot as HTML"""
        if self.group1_dropdown.value and self.group2_dropdown.value:
            fig = self.create_correlation_plot(
                self.group1_dropdown.value,
                self.group2_dropdown.value
            )

            if fig is not None:
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                filename = f"correlation_plot_{timestamp}.html"

                # Create download link
                html_str = fig.to_html()
                b64 = base64.b64encode(html_str.encode()).decode()
                href = f'<a download="{filename}" href="data:text/html;base64,{b64}" target="_blank">Click to download plot</a>'

                with self.plot_output:
                    display(HTML(href))

    def _normalize_column_names(self):
        """Normalize column names in the dataframe"""
        if self.data_transformer.merged_df is not None:
            # Create a copy of the dataframe to avoid modifying the original
            df = self.data_transformer.merged_df.copy()

            # Rename 'Function' to 'function' if it exists
            if 'Function' in df.columns:
                df = df.rename(columns={'Function': 'function'})

            # Update the merged_df in data_transformer
            self.data_transformer.merged_df = df

    def _get_function_options(self):
        """Extract and process function options from the data"""
        if self.data_transformer.merged_df is None:
            return ['All Peptides']

        # Check if function column exists
        has_function = 'function' in self.data_transformer.merged_df.columns

        if not has_function:
            return ['All Peptides']  # Return only 'All Peptides' if no function column exists

        # Extract unique function components and count occurrences
        function_counts = {}

        # Track if we have any functional proteins
        has_functional = False

        for func in self.data_transformer.merged_df['function'].dropna():
            has_functional = True
            # Split by both comma and semicolon
            components = [f.strip() for f in func.replace(';', ',').split(',')]
            for component in components:
                if component:  # Only count non-empty components
                    function_counts[component] = function_counts.get(component, 0) + 1

        # Filter for functions with more than 1 occurrence
        valid_functions = [func for func, count in function_counts.items() if count > 1]

        # Build options list
        options = ['All Peptides', 'Non-Functional']
        if has_functional:
            options.append('Functional')

        # Add individual functions if they exist
        if valid_functions:
            options.extend(sorted(valid_functions))

        return options

    def _setup_observers(self):
        """Setup observers for all interactive widgets"""
        widgets_to_observe = [
            self.group1_dropdown,
            self.group2_dropdown,
            self.correlation_type,
            self.log_transform,
            self.function1_widget,
            self.function2_widget,
            self.function1_color,
            self.function2_color,
            self.xlabel_widget,
            self.ylabel_widget
        ]
        for widget in widgets_to_observe:
            widget.observe(self._on_widget_change, names='value')

    def _filter_data_by_function(self, df, function_filter):
        """Filter dataframe based on function selection"""
        if 'function' in df.columns:
            if function_filter == 'Non-Functional':
                return df[df['function'].isna()]
            elif function_filter == 'Functional':
                return df[df['function'].notna()]
            elif function_filter != 'All Peptides':
                return df[df['function'].fillna('').str.contains(
                    function_filter,
                    case=False,
                    na=False
                )]
        return df

    def create_correlation_plot(self, group1, group2):
        """Generate correlation plot with function filtering and custom labels"""
        if self.data_transformer.merged_df is None or self.data_transformer.group_data is None:
            return None

        df = self.data_transformer.merged_df.copy()
        legend_names = set()

        col1 = f"Avg_{group1}"
        col2 = f"Avg_{group2}"

        if col1 not in df.columns or col2 not in df.columns:
            return None

        # Create figure
        fig = go.Figure()

        # Process each function selection
        for function_filter, color_picker in [
            (self.function1_widget.value, self.function1_color),
            (self.function2_widget.value, self.function2_color)
        ]:
            # Filter data
            filtered_df = self._filter_data_by_function(df.copy(), function_filter)
            filtered_df = filtered_df[(filtered_df[col1] > 0) & (filtered_df[col2] > 0)].copy()

            if len(filtered_df) == 0:
                continue

            # Apply log transformation if selected
            if self.log_transform.value:
                x_values = np.log10(filtered_df[col1])
                y_values = np.log10(filtered_df[col2])
                x_label_prefix = "Log10 "
                y_label_prefix = "Log10 "
                ef = '.2f'
                tickformater = 'f'
                xaxislabel = f'{x_label_prefix}({group1})'
                yaxislabel = f'{x_label_prefix}({group2})'

            else:
                x_values = filtered_df[col1]
                y_values = filtered_df[col2]
                x_label_prefix = ""
                y_label_prefix = ""
                ef = '.1e'
                tickformater = ef
                xaxislabel = f'{group1}'
                yaxislabel = f'{group2}'
            # Calculate correlation
            if len(filtered_df) > 1:
                if self.correlation_type.value == 'Pearson':
                    corr, _ = pearsonr(x_values, y_values)
                else:  # Spearman
                    corr, _ = spearmanr(x_values, y_values)
                correlation_text = f'r = {corr:.3f}'
            else:
                correlation_text = 'n/a'

            # Add scatter points
            hover_data = [filtered_df['unique ID']]
            if 'function' in filtered_df.columns:
                hover_data.append(filtered_df['function'].fillna('Non-Functional'))
            else:
                hover_data.append(['N/A'] * len(filtered_df))

            # Check if this function name is already in legend
            show_in_legend = function_filter not in legend_names
            legend_names.add(function_filter)

            # Add scatter points with conditional legend
            fig.add_trace(go.Scatter(
                x=x_values,
                y=y_values,
                mode='markers',
                name=f'{function_filter} ({correlation_text})',
                marker=dict(color=color_picker.value),
                showlegend=show_in_legend,
                legendgrouptitle_font=dict(size=16),  # Increase legend title font size
                legendgroup=function_filter,  # Group traces with the same function
                hovertemplate=(
                    '<b>Peptide ID:</b> %{customdata[0]}<br>' +
                    '<b>Function:</b> %{customdata[1]}<br>' +
                    f'<b>{group1}:</b> %{{x:{ef}}}<br>' +
                    f'<b>{group2}:</b> %{{y:{ef}}}<br>' +
                    '<extra></extra>'
                ),
                customdata=np.column_stack(hover_data)
            ))

            # Add trendline if we have enough points
            if len(filtered_df) > 1:
                z = np.polyfit(x_values, y_values, 1)
                x_range = np.linspace(x_values.min(), x_values.max(), 100)
                fig.add_trace(go.Scatter(
                    x=x_range,
                    y=np.poly1d(z)(x_range),
                    mode='lines',
                    line=dict(color=color_picker.value, dash='dash'),
                    name=f'{function_filter} trendline',
                    showlegend=False,  # This line removes it from legend
                    hovertemplate='<extra></extra>'
                ))

        if not fig.data:
            return None

        # Get custom labels or use defaults
        xlabel = self.xlabel_widget.value or xaxislabel
        ylabel = self.ylabel_widget.value or yaxislabel

        # Update axis formatting
        fig.update_xaxes(
            title_font={"size": 14},
            tickfont={"size": 12},
            tickformat=tickformater
        )

        fig.update_yaxes(
            title_font={"size": 14},
            tickfont={"size": 12},
            tickformat=tickformater
        )

        # Update layout
        fig.update_layout(
            title=dict(
                text=f'{self.correlation_type.value} Correlation',
                x=0.5,
                xanchor='center'
            ),
            xaxis_title=xlabel,
            yaxis_title=ylabel,
            width=800,
            height=800,
            template='simple_white',
            hoverlabel=dict(
                bgcolor="white",
                font_size=14,
                font_family="Arial"
            ),
            legend=dict(
                yanchor="top",
                y=1.05,
                xanchor="right",
                x=0.99,
                bgcolor="rgba(255, 255, 255, 0.8)"
            )
        )

        # Make aspect ratio equal
        fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1))

        return fig

    def _update_plot(self):
        """Update plot based on current widget values"""
        with self.plot_output:
            self.plot_output.clear_output(wait=True)

            if self.group1_dropdown.value == self.group2_dropdown.value:
                display(HTML('<b style="color:red;">Please select different groups for comparison.</b>'))
                return

            fig = self.create_correlation_plot(
                self.group1_dropdown.value,
                self.group2_dropdown.value
            )

            if fig is not None:
                fig.show(config={
                    'displayModeBar': True,
                    'scrollZoom': True,
                    'modeBarButtonsToRemove': ['select2d', 'lasso2d']
                })
            else:
                display(HTML('<b style="color:red;">No data available for the selected combination.</b>'))

    def _on_widget_change(self, change):
        """Handle any widget value changes"""
        if all([self.group1_dropdown.value, self.group2_dropdown.value]):
            self._update_plot()

    def _on_data_change(self, change):
        """Update widgets when data changes"""
        if self.data_transformer.merged_df is not None:
            self._normalize_column_names()

        if self.data_transformer.merged_df is not None and self.data_transformer.group_data is not None:
            # Update group dropdowns
            group_vars = [info['grouping_variable'] for info in self.data_transformer.group_data.values()]
            self.group1_dropdown.options = group_vars
            self.group2_dropdown.options = group_vars

            # Update function options
            options = self._get_function_options()
            self.function1_widget.options = options
            self.function2_widget.options = options

    def _handle_batch_export(self, b):
        """Handle batch export of correlation plots as PNGs in a ZIP file"""
        if self.data_transformer.group_data is None:
            return
    
        # Clear existing output first
        with self.export_output:
            self.export_output.clear_output(wait=True)
            display(HTML("Generating plots... Please wait."))
    
        # Get all group variables
        group_vars = [info['grouping_variable'] 
                     for info in self.data_transformer.group_data.values()]
        
        # Generate all possible pairs
        pairs = [(group1, group2) 
                for i, group1 in enumerate(group_vars) 
                for group2 in group_vars[i+1:]]
    
        # Create ZIP file in memory
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            # Generate plots for each pair
            for group1, group2 in pairs:
                # Create plot
                fig = self.create_correlation_plot(group1, group2)
                
                if fig is not None:
                    # Generate filename
                    correlation_type = self.correlation_type.value.lower()
                    transform_type = "log10" if self.log_transform.value else "raw"
                    filename = f"correlation_{group1}_vs_{group2}_{correlation_type}_{transform_type}.png"
                    
                    # Save plot as PNG to memory
                    img_buffer = io.BytesIO()
                    fig.write_image(img_buffer, format='png', width=800, height=800)
                    img_buffer.seek(0)
                    
                    # Add to ZIP
                    zip_file.writestr(filename, img_buffer.getvalue())
    
        # Get the ZIP content and encode it
        zip_buffer.seek(0)
        b64_zip = base64.b64encode(zip_buffer.getvalue()).decode()
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"correlation_plots_{timestamp}.zip"
    
        # Create and display the download link
        with self.export_output:
            self.export_output.clear_output(wait=True)
            display(HTML(f'''
                <div id="batch_export_{timestamp}">
                    <a id="batch_link_{timestamp}" 
                       href="data:application/zip;base64,{b64_zip}" 
                       download="{filename}"
                       style="display: none;"></a>
                    <script>
                        (function() {{
                            const link = document.getElementById('batch_link_{timestamp}');
                            if (link) {{
                                link.click();
                                setTimeout(() => {{
                                    const container = document.getElementById('batch_export_{timestamp}');
                                    if (container) container.remove();
                                }}, 1000);
                            }}
                        }})();
                    </script>
                    <div style="color: green; padding: 10px;">
                        Exported {len(pairs)} plots to {filename}
                    </div>
                </div>
            '''))


    def calculate_correlation(self, x, y):
        """Calculate correlation based on selected method"""
        if  self.correlation_type.value == 'Pearson':
            return pearsonr(x, y)[0]
        else:  # Spearman
            return spearmanr(x, y)[0]

    def prepare_data(self, data):
        """Prepare data based on log transform setting"""
        if  self.log_transform.value:
            return np.log10(data)
        return data
    
    def _export_group_correlation_analysis(self, df, group_data):
        """
        Calculate and export correlation analysis to Excel.
        Returns bytes of Excel file content.
        """
        try:
            # Calculate cross-group correlations
            correlation_results = []
            avg_columns = {
                group_info['grouping_variable']: f"Avg_{group_info['grouping_variable']}"
                for group_info in group_data.values()
                if f"Avg_{group_info['grouping_variable']}" in df.columns
            }
            
            # Create Excel writer buffer
            buffer = io.BytesIO()
            with pd.ExcelWriter(buffer, engine='openpyxl') as writer:
                # Cross-group correlations
                for (group1, col1), (group2, col2) in combinations(avg_columns.items(), 2):
                    mask = (df[col1] > 0) & (df[col2] > 0)
                    if mask.sum() > 1:
                        values1 = self.prepare_data(df.loc[mask, col1])
                        values2 = self.prepare_data(df.loc[mask, col2])
                        correlation = self.calculate_correlation(values1, values2)
                        correlation_results.append({
                            'Group 1': group1,
                            'Group 2': group2,
                            'Correlation': round(correlation, 3),
                            'Number of Peptides': mask.sum()
                        })
                
                # Create and write cross-group correlation sheet
                if correlation_results:
                    cross_group_df = pd.DataFrame(correlation_results)
                    cross_group_df.to_excel(writer, sheet_name='Cross-Group Correlations', index=False)
                
                # Calculate summary statistics
                summary_stats = {
                    'Average Correlation': round(np.mean([r['Correlation'] for r in correlation_results]), 3),
                    'Min Correlation': round(min([r['Correlation'] for r in correlation_results]), 3),
                    'Max Correlation': round(max([r['Correlation'] for r in correlation_results]), 3),
                    'Total Comparisons': len(correlation_results)
                }
                
                # Write summary statistics
                pd.DataFrame([summary_stats]).to_excel(writer, sheet_name='Summary', index=False)
    
            return buffer.getvalue()
            
        except Exception as e:
            raise Exception(f"Error in group correlation analysis: {str(e)}")
    
    def _on_export_group_correlation_click(self, b):
        """Handle a click on the cross-group correlation export button.

        Builds the cross-group correlation workbook from the data held by
        ``self.data_transformer`` and offers it as an ``.xlsx`` download inside
        ``self.info_area``. Missing data and errors are reported in the same
        output area.

        Args:
            b: Presumably the clicked button widget passed by ipywidgets'
                ``on_click`` (unused).
        """
        # Stdlib HTML escaping for any text interpolated into rendered markup;
        # exception messages frequently contain '<' / '>' (e.g. "<class 'str'>").
        import html as html_lib

        try:
            with self.info_area:
                self.info_area.clear_output(wait=True)

                merged_df = self.data_transformer.merged_df
                group_data = self.data_transformer.group_data
                if merged_df is None or group_data is None:
                    display(HTML('<div style="color: red; padding: 10px;">No data available for correlation analysis.</div>'))
                    return

                # Get the Excel content
                excel_content = self._export_group_correlation_analysis(merged_df, group_data)

                if excel_content is None:
                    display(HTML('<div style="color: red; padding: 10px;">No correlation data generated.</div>'))
                    return

                # Filename records method and transform so exported files are self-describing.
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                correlation_type = self.correlation_type.value.lower()
                transform_type = "log10" if self.log_transform.value else "raw"
                filename = f"group_correlations_{correlation_type}_{transform_type}_{timestamp}.xlsx"

                b64_content = base64.b64encode(excel_content).decode()
                safe_filename = html_lib.escape(filename)

                # The script starts the download automatically where output scripts
                # are executed; the visible link is the fallback where they are not.
                download_html = f'''
                    <div id="correlation_export_{timestamp}">
                        <a id="correlation_link_{timestamp}"
                           href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64_content}"
                           download="{safe_filename}"
                           style="color: #0366d6; text-decoration: none; font-size: 14px;">
                            Download {safe_filename}
                        </a>
                        <script>
                            try {{
                                document.getElementById('correlation_link_{timestamp}').click();
                                console.log("Download initiated successfully");
                            }} catch(e) {{
                                console.error("Error initiating download:", e);
                            }}
                        </script>
                    </div>
                '''

                display(HTML(download_html))

        except Exception as e:
            with self.info_area:
                self.info_area.clear_output(wait=True)
                display(HTML(f'<div style="color: red; padding: 10px;">Error exporting correlations: {html_lib.escape(str(e))}</div>'))
    
    def _export_replicate_correlation_analysis(self, df, group_data):
        """
        Calculate and export replicate correlation analysis to Excel.
        Returns bytes of Excel file content.
        """
        try:
            # Calculate within-group correlations
            within_group_correlations = {}
            for key, value in group_data.items():
                grouping_variable = value['grouping_variable']
                abundance_columns = value['abundance_columns']
                
                data = df[abundance_columns].copy()
                data = data[data.gt(0).all(axis=1)]  # Filter for rows where all values > 0
                
                if len(data) > 1:
                    data = self.prepare_data(data)
                    
                    # Calculate correlation matrix
                    method = 'pearson' if self.correlation_type.value == 'Pearson' else 'spearman'
                    correlation_matrix = data.corr(method=method)
                    
                    # Get lower triangle only to avoid redundancy
                    lower_triangle = correlation_matrix.where(
                        np.tril(np.ones(correlation_matrix.shape), k=-1).astype(bool)
                    )
                    
                    # Create pairs and get correlation values
                    pairs = []
                    values = []
                    for i in range(len(abundance_columns)):
                        for j in range(i):
                            pair_name = f"{abundance_columns[j]} vs {abundance_columns[i]}"
                            pairs.append(pair_name)
                            values.append(round(lower_triangle.iloc[i,j], 3))
                    
                    within_group_correlations[grouping_variable] = pd.Series(values)
    
            # Create Excel file
            buffer = io.BytesIO()
            with pd.ExcelWriter(buffer, engine='openpyxl') as writer:
                if within_group_correlations:
                    # Create summary sheet with all groups
                    combined_correlation_df = pd.concat(within_group_correlations, axis=1)
                    
                    # Calculate summary statistics
                    min_values = combined_correlation_df.min().round(3)
                    max_values = combined_correlation_df.max().round(3)
                    mean_values = combined_correlation_df.mean().round(3)
                    
                    summary_df = pd.DataFrame({
                        'Min': min_values,
                        'Max': max_values,
                        'Average': mean_values
                    }).T
                    
                    # Combine and write summary
                    combined_with_summary = pd.concat([combined_correlation_df, summary_df], axis=0)
                    combined_with_summary.to_excel(writer, sheet_name='Summary')
                    
                    # Create individual sheets for each group
                    for key, value in group_data.items():
                        grouping_variable = value['grouping_variable']
                        if grouping_variable in within_group_correlations:
                            values = within_group_correlations[grouping_variable]
                            pairs = []
                            for i in range(len(value['abundance_columns'])):
                                for j in range(i):
                                    pairs.append(f"{value['abundance_columns'][j]} vs {value['abundance_columns'][i]}")
                            
                            group_df = pd.DataFrame({
                                'Pair': pairs,
                                'Correlation': values
                            })
                            group_df.to_excel(writer, sheet_name=grouping_variable, index=False)
            
            return buffer.getvalue()
            
        except Exception as e:
            raise Exception(f"Error in replicate correlation analysis: {str(e)}")
            
    def _on_export_replicate_correlation_click(self, b):
        """Handle a click on the replicate (within-group) correlation export button.

        Builds the replicate-correlation workbook from the data held by
        ``self.data_transformer`` and offers it as an ``.xlsx`` download inside
        ``self.info_area``. Missing data and errors are reported in the same
        output area.

        Args:
            b: Presumably the clicked button widget passed by ipywidgets'
                ``on_click`` (unused).
        """
        # Stdlib HTML escaping for any text interpolated into rendered markup;
        # exception messages frequently contain '<' / '>' (e.g. "<class 'str'>").
        import html as html_lib

        try:
            with self.info_area:
                self.info_area.clear_output(wait=True)

                merged_df = self.data_transformer.merged_df
                group_data = self.data_transformer.group_data
                if merged_df is None or group_data is None:
                    display(HTML('<div style="color: red; padding: 10px;">No data available for correlation analysis.</div>'))
                    return

                # Get the Excel content
                excel_content = self._export_replicate_correlation_analysis(merged_df, group_data)

                # Same guard as the cross-group export handler.
                if excel_content is None:
                    display(HTML('<div style="color: red; padding: 10px;">No correlation data generated.</div>'))
                    return

                # Filename records method and transform so exported files are self-describing.
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                correlation_type = self.correlation_type.value.lower()
                transform_type = "log10" if self.log_transform.value else "raw"
                filename = f"correlation_analysis_{correlation_type}_{transform_type}_{timestamp}.xlsx"

                b64_content = base64.b64encode(excel_content).decode()
                safe_filename = html_lib.escape(filename)

                # The script starts the download automatically where output scripts
                # are executed; the visible link is the fallback where they are not.
                display(HTML(f'''
                    <div id="correlation_export_{timestamp}">
                        <a id="correlation_link_{timestamp}"
                           href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64_content}"
                           download="{safe_filename}"
                           style="color: #0366d6; text-decoration: none; font-size: 14px;">
                            Download {safe_filename}
                        </a>
                        <script>
                            try {{
                                document.getElementById('correlation_link_{timestamp}').click();
                            }} catch(e) {{
                                console.error("Error initiating download:", e);
                            }}
                        </script>
                    </div>
                '''))

        except Exception as e:
            with self.info_area:
                self.info_area.clear_output(wait=True)
                display(HTML(f'<div style="color: red; padding: 10px;">Error exporting correlations: {html_lib.escape(str(e))}</div>'))

    def _on_download_html_click(self, b):
        """Export the currently selected correlation plot as a standalone HTML page.

        Reads the groups selected in ``self.group1_dropdown`` and
        ``self.group2_dropdown`` and builds the figure with
        ``self.create_correlation_plot``. It then offers the figure as a
        downloadable ``.html`` file in ``self.export_output``. Validation
        messages and errors are shown in the same output area.

        Args:
            b: Presumably the clicked button widget passed by ipywidgets'
                ``on_click`` (unused).
        """
        # Stdlib HTML escaping for any text interpolated into rendered markup;
        # exception messages frequently contain '<' / '>' (e.g. "<class 'str'>").
        import html as html_lib

        try:
            # Clear any previous export message first.
            with self.export_output:
                self.export_output.clear_output(wait=True)

            group1 = self.group1_dropdown.value
            group2 = self.group2_dropdown.value
            if not (group1 and group2):
                with self.export_output:
                    display(HTML('<div style="color: red; padding: 10px;">Please select two groups to compare first.</div>'))
                return

            fig = self.create_correlation_plot(group1, group2)

            if fig is None:
                with self.export_output:
                    display(HTML('<div style="color: red; padding: 10px;">No data available for the selected groups.</div>'))
                return

            # Same naming scheme as the Excel exports: method, transform, timestamp.
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            correlation_type = self.correlation_type.value.lower()
            transform_type = "log10" if self.log_transform.value else "raw"
            filename = f"interactive_correlation_plot_{correlation_type}_{transform_type}_{timestamp}.html"

            # Let plotly emit the plot <div> and its init script. include_plotlyjs='cdn'
            # loads the plotly.js build matching the installed plotly.py; the
            # unversioned "plotly-latest" CDN URL is frozen at plotly.js v1.58.5.
            plot_markup = fig.to_html(
                full_html=False,
                include_plotlyjs='cdn',
                default_width='1000px',
                default_height='700px',
            )
            html_content = (
                '<!DOCTYPE html>\n'
                '<html>\n'
                '<head>\n'
                '    <meta charset="utf-8">\n'
                '    <title>Interactive Correlation Plot</title>\n'
                '</head>\n'
                '<body>\n'
                f'{plot_markup}\n'
                '</body>\n'
                '</html>\n'
            )

            b64_content = base64.b64encode(html_content.encode('utf-8')).decode()
            safe_filename = html_lib.escape(filename)

            # The script starts the download automatically where output scripts
            # are executed; the visible link is the fallback where they are not.
            with self.export_output:
                display(HTML(f'''
                    <div id="html_export_{timestamp}">
                        <a id="html_link_{timestamp}"
                           href="data:text/html;charset=utf-8;base64,{b64_content}"
                           download="{safe_filename}"
                           style="color: #0366d6; text-decoration: none; font-size: 14px;">
                            Download {safe_filename}
                        </a>
                        <script>
                            (function() {{
                                const link = document.getElementById('html_link_{timestamp}');
                                if (link) {{
                                    link.click();
                                }}
                            }})();
                        </script>
                        <div style="color: green; padding: 10px;">
                            Successfully generated interactive plot: {safe_filename}
                        </div>
                    </div>
                '''))

        except Exception as e:
            with self.export_output:
                self.export_output.clear_output(wait=True)
                display(HTML(f'<div style="color: red; padding: 10px;">Error generating interactive plot: {html_lib.escape(str(e))}</div>'))

    def display(self):
        """Render the correlation analysis interface in the notebook.

        Shows ``self.widget_box``, the plotter's top-level widget container,
        through IPython's module-level ``display`` function.
        """
        interface = self.widget_box
        display(interface)

In [4]:
# Create and show the data-loading widgets, then build the correlation plotter
# on the same DataTransformation instance. The plotter's export handlers read
# data_transformer.merged_df / group_data at click time.
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

correlation_plotter = CorrelationPlotter(data_transformer)

# Render the plotter's widget container (plot controls and export buttons)
correlation_plotter.display()

VBox(children=(HTML(value='<h4>Upload Data Files:</h4>'), HBox(children=(FileUpload(value=(), accept='.csv,.tx…

Output()

VBox(children=(HTML(value='<h4>Plot Controls:</h4>'), Dropdown(description='Group 1:', layout=Layout(width='30…