In [1]:
import json, io, base64, re, os, tempfile, zipfile
import html
from datetime import datetime
from itertools import combinations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
from scipy.stats import pearsonr, spearmanr


In [2]:
class DataTransformation:
    """Upload widgets and loaders for the merged peptide table and group definitions.

    After the user uploads files through the UI built by ``setup_data_loading_ui``:

    * ``merged_df`` holds the merged table as a pandas DataFrame containing at
      least 'Master Protein Accessions' and 'Positions in Proteins' (or None).
    * ``group_data`` maps each group id from the JSON file to a dict with
      'grouping_variable' and 'abundance_columns' (a list), or is None.
    """

    def __init__(self):
        self.merged_df = None
        self.group_data = None
        # Widgets below are created by setup_data_loading_ui().
        self.output_area = None
        self.status_area = None
        self.merged_uploader = None
        self.group_uploader = None
        self.reset_button = None

    def create_download_link(self, file_path, label):
        """Return an HTML widget that downloads ``file_path`` via a base64 data URI.

        Args:
            file_path: local file to embed in the link.
            label: link text shown to the user.

        Returns:
            widgets.HTML: the download link, or a red "not found" message if the
            file does not exist.
        """
        if not os.path.exists(file_path):
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{html.escape(file_path)}" not found!
                </span>
            """)

        # Embed the whole file so the link works without a file server.
        with open(file_path, 'rb') as f:
            b64_content = base64.b64encode(f.read()).decode('utf-8')

        return widgets.HTML(f"""
            <a download="{html.escape(os.path.basename(file_path))}"
               href="data:application/octet-stream;base64,{b64_content}"
               style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                {html.escape(label)}
            </a>
        """)

    def setup_data_loading_ui(self):
        """Create the upload widgets, register their observers and display the UI."""
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.group_uploader = widgets.FileUpload(
            accept='.json',
            multiple=False,
            description='Upload Group Definition',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.output_area = widgets.Output()

        # Create individual upload boxes with example links
        merged_box = widgets.HBox([
            self.merged_uploader,
            self.create_download_link("example_merged_dataframe.csv", "Example")
        ], layout=widgets.Layout(align_items='center'))

        group_box = widgets.HBox([
            self.group_uploader,
            self.create_download_link("example_group_definition.json", "Example")
        ], layout=widgets.Layout(align_items='center'))

        # Create upload section
        upload_widgets = widgets.VBox([
            widgets.HTML("<h4>Upload Data Files:</h4>"),
            merged_box,
            widgets.HTML("<h4>Upload Group Definition:</h4>"),
            group_box,
            self.output_area
        ])

        # Status area shown in the second grid column
        self.status_area = widgets.Output()

        grid = widgets.GridBox(
            [upload_widgets, self.status_area],
            layout=widgets.Layout(
                grid_template_columns='auto auto',
                grid_gap='20px',
                width='100%',
                overflow_x='hidden'
            )
        )

        # Registered here so these loaders run before observers other objects
        # attach later to the same uploaders.
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')
        self.group_uploader.observe(self._on_group_upload_change, names='value')

        display(grid)

    def _on_merged_upload_change(self, change):
        """Load the uploaded merged-data file into ``merged_df`` and report the result."""
        if change['type'] == 'change' and change['name'] == 'value':
            with self.output_area:
                self.output_area.clear_output()
                uploaded_files = change.get('new', ())
                if uploaded_files:
                    # multiple=False, so only the first upload is relevant
                    file_data = uploaded_files[0]
                    self.merged_df = self._load_merged_data(file_data)
                    if self.merged_df is not None:
                        display(HTML(
                            f'<b style="color:green;">Merged data imported: '
                            f'{self.merged_df.shape[0]} rows, {self.merged_df.shape[1]} columns</b>'
                        ))

    def _on_group_upload_change(self, change):
        """Parse the uploaded group-definition JSON into ``group_data`` and report the result."""
        if change['type'] == 'change' and change['name'] == 'value':
            with self.output_area:
                self.output_area.clear_output()
                uploaded_files = change.get('new', ())
                if uploaded_files:
                    # multiple=False, so only the first upload is relevant
                    file_data = uploaded_files[0]
                    try:
                        # 'utf-8-sig' also accepts files saved with a byte-order mark,
                        # which json.loads would otherwise reject.
                        content = bytes(file_data.content).decode('utf-8-sig')
                        group_data = json.loads(content)
                        self.group_data = self._process_group_data(group_data)
                        display(HTML(
                            f'<b style="color:green;">Group definition imported successfully with {len(self.group_data)} groups.</b><br>'
                        ))
                    except Exception as e:
                        display(HTML(
                            f'<b style="color:red;">Error loading group definition: {html.escape(str(e))}</b>'
                        ))

    def _load_merged_data(self, file_data):
        """Read an uploaded table into a DataFrame and validate its required columns.

        Args:
            file_data: uploaded file object with ``name`` and ``content`` attributes.

        Returns:
            pandas.DataFrame or None: the loaded table. On any error, None is
            returned and the error is displayed.
        """
        try:
            content = bytes(file_data.content)
            extension = file_data.name.rsplit('.', 1)[-1].lower()
            file_stream = io.BytesIO(content)

            if extension == 'csv':
                df = pd.read_csv(file_stream)
            elif extension in ('txt', 'tsv'):
                df = pd.read_csv(file_stream, delimiter='\t')
            elif extension == 'xlsx':
                df = pd.read_excel(file_stream)
            else:
                raise ValueError("Unsupported file format")

            required_columns = ['Master Protein Accessions', 'Positions in Proteins']
            missing = [col for col in required_columns if col not in df.columns]
            if missing:
                raise ValueError(f"Missing required columns: {', '.join(missing)}")

            return df

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error loading data: {html.escape(str(e))}</b>'))
            return None

    def _process_group_data(self, json_data):
        """Validate parsed group-definition JSON and keep only the fields used downstream.

        Args:
            json_data: parsed JSON, expected to be an object mapping group ids to
                objects with 'grouping_variable' and 'abundance_columns' (a list).

        Returns:
            dict: ``{group_id: {'grouping_variable': ..., 'abundance_columns': [...]}}``.

        Raises:
            ValueError: if the structure does not match the description above.
        """
        try:
            if not isinstance(json_data, dict):
                raise ValueError("top-level JSON must be an object mapping group ids to definitions")

            processed_data = {}
            for group_id, group_info in json_data.items():
                if not isinstance(group_info, dict):
                    raise ValueError(f"Group {group_id} must be an object")
                if 'grouping_variable' not in group_info:
                    raise ValueError(f"Group {group_id} missing grouping_variable")
                if 'abundance_columns' not in group_info:
                    raise ValueError(f"Group {group_id} missing abundance_columns")
                # The export code does df[abundance_columns] and expects a 2-D frame,
                # so a bare column-name string must be rejected here.
                if not isinstance(group_info['abundance_columns'], list):
                    raise ValueError(f"Group {group_id} abundance_columns must be a list of column names")

                processed_data[group_id] = {
                    'grouping_variable': group_info['grouping_variable'],
                    'abundance_columns': group_info['abundance_columns']
                }

            return processed_data
        except Exception as e:
            raise ValueError(f"Error processing group data: {str(e)}") from e


In [3]:
class CorrelationPlotter:
    """Interactive scatter plot of two groups' average abundances with correlation stats.

    Reads ``merged_df`` and ``group_data`` from ``data_transformer``. It plots
    column ``Avg_<group1>`` against ``Avg_<group2>`` for up to two function
    filters, each with its own colour, correlation coefficient and trendline.
    """

    def __init__(self, data_transformer):
        """Build all control widgets and register observers.

        Args:
            data_transformer: object exposing ``merged_df``, ``group_data`` and the
                ``merged_uploader`` / ``group_uploader`` widgets (which may still be
                None; see ``setup_data_transformer_observers``).
        """
        self.data_transformer = data_transformer
        self.plot_output = widgets.Output()
        self.info_output = widgets.Output()
        self.export_output = widgets.Output()

        # Initial options; refreshed in _on_data_change once data is loaded
        self.function_options = self._get_function_options()

        self.download_button = widgets.Button(
            description='Download Plot as HTML',
            layout=widgets.Layout(width='200px'),
            button_style='info',
            icon='save'
        )
        self.download_button.on_click(self._download_plot)

        # Create widgets for correlation analysis
        self.group1_dropdown = widgets.Dropdown(
            description='Group 1:',
            options=[],
            layout=widgets.Layout(width='300px')
        )

        self.group2_dropdown = widgets.Dropdown(
            description='Group 2:',
            options=[],
            layout=widgets.Layout(width='300px')
        )

        self.correlation_type = widgets.Dropdown(
            options=['Pearson', 'Spearman'],
            description='Correlation:',
            value='Pearson',
            layout=widgets.Layout(width='300px')
        )

        self.log_transform = widgets.Checkbox(
            value=True,
            description='Log10 transform data',
            layout=widgets.Layout(width='300px')
        )

        # Function selection with colors
        self.function1_widget = widgets.Dropdown(
            options=self.function_options,
            description='Function 1:',
            value='All Peptides',
            layout=widgets.Layout(width='300px')
        )

        self.function1_color = widgets.ColorPicker(
            concise=False,
            description='Color:',
            value='#0072C6',
            layout=widgets.Layout(width='200px')
        )

        self.function2_widget = widgets.Dropdown(
            options=self.function_options,
            description='Function 2:',
            value='All Peptides',
            layout=widgets.Layout(width='300px')
        )

        self.function2_color = widgets.ColorPicker(
            concise=False,
            description='Color:',
            value='#0072C6',
            layout=widgets.Layout(width='200px')
        )

        # Label customization widgets
        self.xlabel_widget = widgets.Text(
            description='X Label:',
            placeholder='Enter x-axis label',
            layout=widgets.Layout(width='300px')
        )

        self.ylabel_widget = widgets.Text(
            description='Y Label:',
            placeholder='Enter y-axis label',
            layout=widgets.Layout(width='300px')
        )

        self.function1_box = widgets.HBox([
            self.function1_widget,
            self.function1_color
        ])

        self.function2_box = widgets.HBox([
            self.function2_widget,
            self.function2_color
        ])

        # Single layout containing every control. The download button belongs
        # under "Actions:"; it used to sit in a VBox that was immediately
        # replaced, so it was never displayed.
        self.widget_box = widgets.VBox([
            widgets.HTML("<h4>Plot Controls:</h4>"),
            self.group1_dropdown,
            self.group2_dropdown,
            self.correlation_type,
            self.log_transform,
            self.function1_box,
            self.function2_box,
            widgets.HTML("<h4>Appearance Settings:</h4>"),
            self.xlabel_widget,
            self.ylabel_widget,
            widgets.HTML("<h4>Actions:</h4>"),
            self.download_button,
            self.info_output,
            self.export_output,
            self.plot_output
        ])

        # Watch for uploads. This is a no-op if the uploaders do not exist yet;
        # call setup_data_transformer_observers() again after they are created.
        self.setup_data_transformer_observers()

        self._setup_observers()

    def setup_data_transformer_observers(self):
        """Observe the data transformer's uploaders, skipping any that are not created yet."""
        for uploader in (getattr(self.data_transformer, 'merged_uploader', None),
                         getattr(self.data_transformer, 'group_uploader', None)):
            if uploader is not None:
                uploader.observe(self._on_data_change, names='value')

    def _download_plot(self, b):
        """Show a link that downloads the current plot as a standalone HTML file.

        Args:
            b: the clicked button (unused; required by ``Button.on_click``).
        """
        if not (self.group1_dropdown.value and self.group2_dropdown.value):
            return

        fig = self.create_correlation_plot(
            self.group1_dropdown.value,
            self.group2_dropdown.value
        )

        with self.plot_output:
            if fig is None:
                display(HTML('<b style="color:red;">No plot available to download.</b>'))
                return

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"correlation_plot_{timestamp}.html"
            b64 = base64.b64encode(fig.to_html().encode()).decode()
            display(HTML(
                f'<a download="{filename}" href="data:text/html;base64,{b64}" target="_blank">Click to download plot</a>'
            ))

    def _normalize_column_names(self):
        """Rename a 'Function' column to 'function' in the transformer's ``merged_df``."""
        df = self.data_transformer.merged_df
        # Skip when 'function' already exists; renaming would create a duplicate column.
        if df is not None and 'Function' in df.columns and 'function' not in df.columns:
            # rename() returns a new frame, so the original object is not mutated
            self.data_transformer.merged_df = df.rename(columns={'Function': 'function'})

    @staticmethod
    def _split_functions(value):
        """Split a function annotation on ',' and ';' into stripped, non-empty parts (NA -> [])."""
        if pd.isna(value):
            return []
        return [part.strip() for part in str(value).replace(';', ',').split(',') if part.strip()]

    def _get_function_options(self):
        """Return dropdown options for the function filters.

        Returns:
            list[str]: 'All Peptides' and 'Non-Functional', then 'Functional' if any
            row is annotated, then (sorted) every function component that occurs
            more than once. Only 'All Peptides' is returned when there is no data
            or no 'function' column.
        """
        df = self.data_transformer.merged_df
        if df is None or 'function' not in df.columns:
            return ['All Peptides']

        annotated = df['function'].dropna()
        function_counts = {}
        for value in annotated:
            for component in self._split_functions(value):
                function_counts[component] = function_counts.get(component, 0) + 1

        options = ['All Peptides', 'Non-Functional']
        if len(annotated) > 0:
            options.append('Functional')
        options.extend(sorted(func for func, count in function_counts.items() if count > 1))
        return options

    def _setup_observers(self):
        """Re-plot whenever any plot control changes."""
        widgets_to_observe = [
            self.group1_dropdown,
            self.group2_dropdown,
            self.correlation_type,
            self.log_transform,
            self.function1_widget,
            self.function2_widget,
            self.function1_color,
            self.function2_color,
            self.xlabel_widget,
            self.ylabel_widget
        ]
        for widget in widgets_to_observe:
            widget.observe(self._on_widget_change, names='value')

    def _filter_data_by_function(self, df, function_filter):
        """Return the rows of ``df`` matching a function-filter option.

        'All Peptides' keeps every row. 'Non-Functional' keeps rows without a
        'function' value, and 'Functional' keeps rows with one. Any other value
        keeps rows whose ','/';'-separated function components contain it exactly
        (case-insensitive). This is the same splitting used to build the options.
        Plain string comparison is used, so names such as "ACE-inhibitory (IC50)"
        are not treated as regular expressions. If ``df`` has no 'function'
        column, it is returned unchanged.
        """
        if 'function' not in df.columns or function_filter == 'All Peptides':
            return df
        if function_filter == 'Non-Functional':
            return df[df['function'].isna()]
        if function_filter == 'Functional':
            return df[df['function'].notna()]

        target = str(function_filter).strip().lower()
        mask = df['function'].map(
            lambda value: target in {part.lower() for part in self._split_functions(value)}
        )
        return df[mask.astype(bool)]

    def create_correlation_plot(self, group1, group2):
        """Build the correlation scatter plot of ``Avg_<group1>`` vs ``Avg_<group2>``.

        For each of the two function filters, rows with positive values in both
        columns are plotted. Each filter gets a correlation coefficient (Pearson
        or Spearman, as selected) in its legend entry and a dashed least-squares
        trendline. Values are log10-transformed when the checkbox is set.

        Returns:
            plotly.graph_objects.Figure or None: None when data is missing, either
            average column is absent, or no filter selects any rows.
        """
        df = self.data_transformer.merged_df
        if df is None or self.data_transformer.group_data is None:
            return None

        col1 = f"Avg_{group1}"
        col2 = f"Avg_{group2}"
        if col1 not in df.columns or col2 not in df.columns:
            return None

        # Axis labels and number formats depend only on the transform setting.
        use_log = self.log_transform.value
        if use_log:
            value_format = '.2f'
            tick_format = 'f'
            default_xlabel = f'Log10 ({group1})'
            default_ylabel = f'Log10 ({group2})'
        else:
            value_format = '.1e'
            tick_format = value_format
            default_xlabel = f'{group1}'
            default_ylabel = f'{group2}'

        fig = go.Figure()
        legend_names = set()

        for function_filter, color_picker in [
            (self.function1_widget.value, self.function1_color),
            (self.function2_widget.value, self.function2_color)
        ]:
            subset = self._filter_data_by_function(df, function_filter)
            # Positive values only: required for log10, and applied to raw data too.
            subset = subset[(subset[col1] > 0) & (subset[col2] > 0)]
            if subset.empty:
                continue

            if use_log:
                x_values = np.log10(subset[col1])
                y_values = np.log10(subset[col2])
            else:
                x_values = subset[col1]
                y_values = subset[col2]

            if len(subset) > 1:
                corr_func = pearsonr if self.correlation_type.value == 'Pearson' else spearmanr
                corr, _ = corr_func(x_values, y_values)
                correlation_text = f'r = {corr:.3f}'
            else:
                correlation_text = 'n/a'

            # 'unique ID' is not among the columns validated on upload, so fall
            # back to the row index rather than raising KeyError.
            if 'unique ID' in subset.columns:
                peptide_ids = subset['unique ID']
            else:
                peptide_ids = subset.index.to_numpy()
            if 'function' in subset.columns:
                function_labels = subset['function'].fillna('Non-Functional')
            else:
                function_labels = ['N/A'] * len(subset)

            # Show each filter name in the legend only once
            show_in_legend = function_filter not in legend_names
            legend_names.add(function_filter)

            fig.add_trace(go.Scatter(
                x=x_values,
                y=y_values,
                mode='markers',
                name=f'{function_filter} ({correlation_text})',
                marker=dict(color=color_picker.value),
                showlegend=show_in_legend,
                hovertemplate=(
                    '<b>Peptide ID:</b> %{customdata[0]}<br>' +
                    '<b>Function:</b> %{customdata[1]}<br>' +
                    f'<b>{group1}:</b> %{{x:{value_format}}}<br>' +
                    f'<b>{group2}:</b> %{{y:{value_format}}}<br>' +
                    '<extra></extra>'
                ),
                customdata=np.column_stack([peptide_ids, function_labels])
            ))

            # Least-squares trendline over the observed x range
            if len(subset) > 1:
                trend = np.poly1d(np.polyfit(x_values, y_values, 1))
                x_range = np.linspace(x_values.min(), x_values.max(), 100)
                fig.add_trace(go.Scatter(
                    x=x_range,
                    y=trend(x_range),
                    mode='lines',
                    line=dict(color=color_picker.value, dash='dash'),
                    name=f'{function_filter} trendline',
                    showlegend=False,
                    hovertemplate='<extra></extra>'
                ))

        if not fig.data:
            return None

        # Custom labels override the defaults
        xlabel = self.xlabel_widget.value or default_xlabel
        ylabel = self.ylabel_widget.value or default_ylabel

        axis_style = dict(
            title_font={"size": 14},
            tickfont={"size": 12},
            tickformat=tick_format
        )
        fig.update_xaxes(**axis_style)
        fig.update_yaxes(**axis_style)

        fig.update_layout(
            title=dict(
                text=f'{self.correlation_type.value} Correlation',
                x=0.5,
                xanchor='center'
            ),
            xaxis_title=xlabel,
            yaxis_title=ylabel,
            width=800,
            height=800,
            template='simple_white',
            hoverlabel=dict(
                bgcolor="white",
                font_size=14,
                font_family="Arial"
            ),
            legend=dict(
                yanchor="top",
                y=1.05,
                xanchor="right",
                x=0.99,
                bgcolor="rgba(255, 255, 255, 0.8)"
            )
        )

        # Equal aspect ratio so a 1:1 relationship appears as a diagonal
        fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1))

        return fig

    def _update_plot(self):
        """Redraw the plot in ``plot_output`` from the current widget values."""
        with self.plot_output:
            self.plot_output.clear_output(wait=True)

            if self.group1_dropdown.value == self.group2_dropdown.value:
                display(HTML('<b style="color:red;">Please select different groups for comparison.</b>'))
                return

            fig = self.create_correlation_plot(
                self.group1_dropdown.value,
                self.group2_dropdown.value
            )

            if fig is not None:
                fig.show(config={
                    'displayModeBar': True,
                    'scrollZoom': True,
                    'modeBarButtonsToRemove': ['select2d', 'lasso2d']
                })
            else:
                display(HTML('<b style="color:red;">No data available for the selected combination.</b>'))

    def _on_widget_change(self, change):
        """Re-plot once both groups are selected."""
        if all([self.group1_dropdown.value, self.group2_dropdown.value]):
            self._update_plot()

    def _on_data_change(self, change):
        """Refresh group and function options after either upload changes."""
        if self.data_transformer.merged_df is not None:
            self._normalize_column_names()

        if self.data_transformer.merged_df is not None and self.data_transformer.group_data is not None:
            group_vars = [info['grouping_variable'] for info in self.data_transformer.group_data.values()]
            self.group1_dropdown.options = group_vars
            self.group2_dropdown.options = group_vars

            options = self._get_function_options()
            self.function1_widget.options = options
            self.function2_widget.options = options

    def display(self):
        """Display the correlation analysis interface."""
        display(self.widget_box)

In [4]:
class CorrelationExportHandler:
    def __init__(self, data_transformer, correlation_plotter):
        """Create the batch-export widgets and subscribe to upload changes.

        Args:
            data_transformer: object whose ``merged_uploader`` and
                ``group_uploader`` widgets are observed here, and which holds the
                loaded data.
            correlation_plotter: plotter whose current settings drive the exports.
        """
        self.data_transformer = data_transformer
        self.correlation_plotter = correlation_plotter
        self.export_output = widgets.Output()

        # The button starts hidden (display='none').
        button_layout = widgets.Layout(
            width='300px',
            height='40px',
            margin='0',
            display='none'
        )
        button_style = {
            'button_color': '#4682B4',
            'text_color': 'blue',
            'font_size': '14px'
        }
        self.batch_export_button = widgets.Button(
            description='Export All Combinations as ZIP',
            layout=button_layout,
            style=button_style
        )
        self.batch_export_button.on_click(self._handle_batch_export)

        # CSS for the generated download links and export sections
        self.style = """
            <style>
            .download-link {
                background-color: #4CAF50;
                border: none;
                color: white;
                padding: 10px 20px;
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-size: 14px;
                margin: 4px 2px;
                cursor: pointer;
                border-radius: 4px;
            }
            .download-link:hover {
                background-color: #45a049;
            }
            .export-section {
                margin-bottom: 20px;
                padding: 15px;
                border-radius: 5px;
                background-color: #f8f9fa;
            }
            .export-description {
                color: #666;
                margin: 5px 0 15px 0;
                font-style: italic;
            }
            .hidden {
                display: none !important;
            }
            </style>
        """

        # Placeholder for the batch-export section's HTML
        self.batch_export_container = widgets.HTML()

        # React to new uploads of either file (merged data first, then groups)
        for uploader in (self.data_transformer.merged_uploader,
                         self.data_transformer.group_uploader):
            uploader.observe(self._on_data_change, names='value')

        

    def _handle_batch_export(self, b):
        """Export every pairwise group correlation plot as PNGs bundled in one ZIP.

        This is the callback for ``batch_export_button``. Plots are built with the
        correlation plotter's current settings. The PNG step (``fig.to_image``)
        needs plotly's static-image engine to be installed. Failures are reported
        in ``export_output`` instead of escaping the button callback. In every
        case ``generate_export_links()`` is called afterwards.

        Args:
            b: the clicked button (unused; required by ``Button.on_click``).
        """
        if self.data_transformer.group_data is None:
            return

        with self.export_output:
            self.export_output.clear_output(wait=True)
            display(HTML("Generating plots... Please wait."))

        group_vars = [info['grouping_variable']
                      for info in self.data_transformer.group_data.values()]
        # Every unordered pair of groups, in definition order
        pairs = list(combinations(group_vars, 2))

        plotter = self.correlation_plotter
        correlation_type = plotter.correlation_type.value.lower()
        transform_type = "log10" if plotter.log_transform.value else "raw"

        exported = 0
        error = None
        zip_buffer = io.BytesIO()
        try:
            with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                for group1, group2 in pairs:
                    fig = plotter.create_correlation_plot(group1, group2)
                    if fig is None:
                        continue
                    # Group names come from the uploaded JSON. Replace characters
                    # that would create sub-folders or invalid archive entry names.
                    name1 = re.sub(r'[^\w.-]+', '_', str(group1))
                    name2 = re.sub(r'[^\w.-]+', '_', str(group2))
                    png_name = f"correlation_{name1}_vs_{name2}_{correlation_type}_{transform_type}.png"
                    zip_file.writestr(png_name, fig.to_image(format='png', width=800, height=800))
                    exported += 1
        except Exception as e:
            error = e

        with self.export_output:
            self.export_output.clear_output(wait=True)
            if error is not None:
                display(HTML(
                    f'<b style="color:red;">Batch export failed: {html.escape(str(error))}</b>'
                ))
            elif exported == 0:
                display(HTML(
                    '<b style="color:red;">No plots could be generated for the selected groups.</b>'
                ))
            else:
                # Read the bytes only after the `with` block has closed the archive.
                zip_bytes = zip_buffer.getvalue()
                b64_zip = base64.b64encode(zip_bytes).decode()
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                filename = f"correlation_plots_{timestamp}.zip"

                # Trigger the browser download automatically.
                download_script = f"""
        <script>
        (function() {{
            const link = document.createElement('a');
            link.href = 'data:application/zip;base64,{b64_zip}';
            link.download = '{filename}';
            document.body.appendChild(link);
            link.click();
            document.body.removeChild(link);
        }})();
        </script>
        """
                display(HTML(f"Exported {exported} plots to {filename}"))
                display(HTML(download_script))
                # Visible fallback for front-ends that do not execute <script> output
                display(HTML(self.style + self.generate_download_link(
                    zip_bytes, filename, 'application/zip'
                )))
            self.generate_export_links()
    def generate_download_link(self, content, filename, filetype='text/csv'):
        """Generate a download link for any content"""
        if isinstance(content, pd.DataFrame):
            if filetype == 'text/csv':
                content = content.to_csv(index=False)
            else:
                content = content.to_csv(index=True)
        if isinstance(content, str):
            content = content.encode()
        b64 = base64.b64encode(content).decode()
        return f"""
            <a download="{filename}" href="data:{filetype};base64,{b64}" class="download-link" 
               title="Click to download">
                Download {filename}
            </a>
        """

    def calculate_correlation(self, x, y):
        """Calculate correlation based on selected method"""
        if self.correlation_plotter.correlation_type.value == 'Pearson':
            return pearsonr(x, y)[0]
        else:  # Spearman
            return spearmanr(x, y)[0]

    def prepare_data(self, data):
        """Prepare data based on log transform setting"""
        if self.correlation_plotter.log_transform.value:
            return np.log10(data)
        return data

    def generate_export_links(self):
        """Build and display download links for the current correlation analysis.

        Renders three export sections into ``self.export_output``:

        * the currently selected cross-group plot as an interactive HTML file,
        * correlations between every pair of ``Avg_<grouping_variable>`` columns
          as CSV,
        * within-group pairwise correlations of each group's abundance columns as
          an Excel workbook (a ``Summary`` sheet plus one sheet per group).

        The correlation method (Pearson/Spearman) and optional log10 transform are
        read from ``self.correlation_plotter``. Only rows where every compared
        value is > 0 are used, so the log transform is always defined.
        Does nothing until both the merged data and the group definitions are set
        on ``self.data_transformer``.
        """
        if self.data_transformer.merged_df is None or self.data_transformer.group_data is None:
            return

        df = self.data_transformer.merged_df.copy()
        group_data = self.data_transformer.group_data.copy()

        cross_group_correlations = self._compute_cross_group_correlations(df, group_data)
        within_group_correlations = self._compute_within_group_correlations(df, group_data)
        excel_bytes = self._build_within_group_workbook(within_group_correlations)

        # Filenames encode method, transform and time so repeated exports don't collide.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        correlation_type = self.correlation_plotter.correlation_type.value.lower()
        transform_type = "log10" if self.correlation_plotter.log_transform.value else "raw"
        suffix = f"{correlation_type}_{transform_type}_{timestamp}"

        plot_link = self._build_current_plot_link(f"correlation_plot_{suffix}.html")
        csv_link = self.generate_download_link(
            cross_group_correlations,
            f"cross_group_correlations_{suffix}.csv"
        )
        excel_link = self.generate_download_link(
            excel_bytes,
            f"within_group_correlations_{suffix}.xlsx",
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
        )

        # Display the export sections
        with self.export_output:
            self.export_output.clear_output(wait=True)
            html_content = f"""
            {self.style}
            <div class="export-section">
                <h3>Current Plot Export</h3>
                <div class="export-description">
                    Download current correlation plot as interactive HTML file
                </div>
                {plot_link}
            </div>

            <div class="export-section">
                <h3>Cross-Group Correlations</h3>
                <div class="export-description">
                    Download correlations between different sample groups (CSV format)
                </div>
                {csv_link}
            </div>
            
            <div class="export-section">
                <h3>Within-Group Correlations</h3>
                <div class="export-description">
                    Download detailed correlation analysis for each group with summary statistics (Excel format)
                </div>
                {excel_link}
            </div>
            """
            
            display(HTML(html_content))

    def _compute_cross_group_correlations(self, df, group_data):
        """Return a DataFrame of correlations between every pair of group averages.

        Only groups whose ``Avg_<grouping_variable>`` column exists in ``df`` take
        part. Each pair uses rows that are > 0 in both columns and is skipped when
        fewer than two such rows exist. Columns: 'Group 1', 'Group 2',
        'Correlation' (rounded to 3 decimals).
        """
        avg_columns = {
            info['grouping_variable']: f"Avg_{info['grouping_variable']}"
            for info in group_data.values()
            if f"Avg_{info['grouping_variable']}" in df.columns
        }

        results = []
        for (group1, col1), (group2, col2) in combinations(avg_columns.items(), 2):
            mask = (df[col1] > 0) & (df[col2] > 0)
            if mask.sum() > 1:
                values1 = self.prepare_data(df.loc[mask, col1])
                values2 = self.prepare_data(df.loc[mask, col2])
                correlation = round(self.calculate_correlation(values1, values2), 3)
                results.append((group1, group2, correlation))

        return pd.DataFrame(results, columns=['Group 1', 'Group 2', 'Correlation'])

    def _compute_within_group_correlations(self, df, group_data):
        """Return ``{grouping_variable: (pair_labels, correlations)}`` per group.

        For each group, pairwise correlations between its abundance columns are
        computed on rows where all of those columns are > 0; groups with fewer
        than two such rows are omitted. Pair labels read ``"<col_j> vs <col_i>"``
        with j < i, i.e. the strict lower triangle of the correlation matrix.
        Correlations are rounded to 3 decimals and stored as a float Series.
        """
        method = 'pearson' if self.correlation_plotter.correlation_type.value == 'Pearson' else 'spearman'

        results = {}
        for info in group_data.values():
            grouping_variable = info['grouping_variable']
            # Ignore abundance columns absent from the merged data (mirrors the
            # existence check on the Avg_ columns) instead of raising KeyError.
            columns = [col for col in info['abundance_columns'] if col in df.columns]

            data = df[columns]
            data = data[data.gt(0).all(axis=1)]
            if len(data) <= 1:
                continue

            correlation_matrix = self.prepare_data(data).corr(method=method)

            pairs, values = [], []
            for i in range(len(columns)):
                for j in range(i):
                    pairs.append(f"{columns[j]} vs {columns[i]}")
                    values.append(round(correlation_matrix.iloc[i, j], 3))

            # Explicit float dtype: an empty (object-dtype) Series would make the
            # summary's .round() fail for groups with fewer than two columns.
            results[grouping_variable] = (pairs, pd.Series(values, dtype=float))
        return results

    def _build_within_group_workbook(self, within_group_correlations):
        """Serialise within-group correlations to .xlsx and return the file bytes.

        The 'Summary' sheet holds one column per group plus Min/Max/Average rows;
        each group then gets its own sheet of (Pair, Correlation) rows.
        """
        buffer = io.BytesIO()
        with pd.ExcelWriter(buffer, engine='openpyxl') as writer:
            if not within_group_correlations:
                # openpyxl cannot save a workbook with no sheets, so always write
                # an explanatory sheet rather than crashing on close.
                pd.DataFrame(
                    {'Message': ['No group has enough positive rows to compute correlations']}
                ).to_excel(writer, sheet_name='Summary', index=False)
            else:
                combined = pd.concat(
                    {group: values for group, (_, values) in within_group_correlations.items()},
                    axis=1
                )
                summary = pd.DataFrame({
                    'Min': combined.min().round(3),
                    'Max': combined.max().round(3),
                    'Average': combined.mean().round(3)
                }).T
                pd.concat([combined, summary], axis=0).to_excel(writer, sheet_name='Summary')

                # Excel sheet names are case-insensitive; reserve 'Summary'.
                used_names = {'summary'}
                for group, (pairs, values) in within_group_correlations.items():
                    sheet_name = self._excel_sheet_name(group, used_names)
                    pd.DataFrame({'Pair': pairs, 'Correlation': values}).to_excel(
                        writer, sheet_name=sheet_name, index=False
                    )
        return buffer.getvalue()

    @staticmethod
    def _excel_sheet_name(name, used_names):
        """Return a valid, unique Excel sheet name derived from ``name``.

        Excel limits sheet names to 31 characters, forbids the characters
        [ ] : * ? / and backslash, disallows leading/trailing apostrophes and
        compares names case-insensitively. ``used_names`` holds the lower-cased
        names already taken and is updated in place.
        """
        base = re.sub(r'[\[\]:*?/\\]', '_', str(name))[:31].strip("'") or 'Group'
        candidate = base
        counter = 1
        while candidate.lower() in used_names:
            tag = f"_{counter}"
            candidate = base[:31 - len(tag)] + tag
            counter += 1
        used_names.add(candidate.lower())
        return candidate

    def _build_current_plot_link(self, plot_filename):
        """Return HTML for downloading the currently selected cross-group plot.

        Falls back to a red notice when no group pair is selected or when the
        plotter returns no figure for the selection.
        """
        plotter = self.correlation_plotter
        groups_selected = (
            hasattr(plotter, 'group1_dropdown')
            and hasattr(plotter, 'group2_dropdown')
            and plotter.group1_dropdown.value
            and plotter.group2_dropdown.value
        )
        if not groups_selected:
            return '<span style="color: red;">Please select groups for comparison</span>'

        fig = plotter.create_correlation_plot(
            plotter.group1_dropdown.value,
            plotter.group2_dropdown.value
        )
        if fig is None:
            return '<span style="color: red;">No data available for selected groups</span>'
        return self.generate_download_link(fig.to_html(), plot_filename, 'text/html')
        

    def _on_data_change(self, change):
        """Handle changes in data uploads"""
        self._update_export_visibility()
        self.generate_export_links()

    def _update_export_visibility(self):
        """Update visibility of export options based on data availability"""
        has_data = (self.data_transformer.merged_df is not None and 
                   self.data_transformer.group_data is not None)
        
        # Update batch export section visibility
        button_id = f'batch_export_button_{id(self)}'
        batch_export_html = f"""
            {self.style}
            <div class="export-section {'hidden' if not has_data else ''}">
                <h3>Batch Export Plots</h3>
                <div class="export-description">
                    Export all possible group combinations as PNG plots in a ZIP file
                </div>
                <a href="#" class="download-link" onclick="document.getElementsByClassName('{button_id}')[0].click(); return false;">
                    Export All Combinations as ZIP
                </a>
            </div>
        """
        self.batch_export_container.value = batch_export_html

    def display(self):
        """Render the export UI and populate it.

        Tags the batch-export button with a per-instance CSS class (the one the
        batch-export section's link targets), hides the button, then displays the
        heading, the button, the batch-export container and the link output area,
        in that order. Finally refreshes section visibility and download links.
        """
        button = self.batch_export_button
        button._dom_classes = ['batch_export_button', f'batch_export_button_{id(self)}']

        # Hidden from view, but still present for the section's link to click.
        layout = button.layout
        layout.display = 'none'
        layout.visibility = 'hidden'

        display(HTML(self.style + "<h2><u>Export:</u></h2>"))
        for widget in (button, self.batch_export_container, self.export_output):
            display(widget)

        self._update_export_visibility()
        self.generate_export_links()

In [5]:

# Create components
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()  # Set up UI first

# Now create the correlation plotter
correlation_plotter = CorrelationPlotter(data_transformer)
correlation_plotter.setup_data_transformer_observers()  # Set up observers after UI is ready

# Create export handler
export_handler = CorrelationExportHandler(data_transformer, correlation_plotter)

# Display components
correlation_plotter.display()
export_handler.display()


def update_export_links(change):
    """Regenerate the export links after a plot setting changes (`change` is ignored)."""
    export_handler.generate_export_links()


# Every plot control whose value affects the exported plot or correlation values.
for plot_control in (
    correlation_plotter.group1_dropdown,
    correlation_plotter.group2_dropdown,
    correlation_plotter.correlation_type,
    correlation_plotter.log_transform,
    correlation_plotter.function1_widget,
    correlation_plotter.function2_widget,
    correlation_plotter.function1_color,
    correlation_plotter.function2_color,
):
    plot_control.observe(update_export_links, names='value')

# A new upload changes data availability, so the batch-export section's visibility
# must be refreshed as well as the links; _on_data_change does both. (Links alone
# would leave the section stuck in the hidden state rendered by display().)
for uploader in (data_transformer.merged_uploader, data_transformer.group_uploader):
    uploader.observe(export_handler._on_data_change, names='value')


GridBox(children=(VBox(children=(HTML(value='<h4>Upload Data Files:</h4>'), HBox(children=(FileUpload(value=()…

VBox(children=(HTML(value='<h4>Plot Controls:</h4>'), Dropdown(description='Group 1:', layout=Layout(width='30…

Button(description='Export All Combinations as ZIP', layout=Layout(display='none', height='40px', margin='0', …

HTML(value='')

Output()