In [1]:
import base64
import html
import io
import json
import os
import re
from datetime import datetime
from itertools import combinations

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import display, HTML, clear_output
from scipy.stats import pearsonr


In [2]:
class DataTransformation:
    """Upload panel that loads the merged peptide table and the group definition.

    After a successful upload, ``merged_df`` holds the parsed table and
    ``group_data`` holds ``{group_id: {'grouping_variable': ...,
    'abundance_columns': [...]}}``. A failed upload resets the corresponding
    attribute to ``None`` so dependent views do not keep using a stale file.
    """

    def __init__(self):
        self.merged_df = None
        self.group_data = None
        self.output_area = None
        self.status_area = None  # created in setup_data_loading_ui
        self.merged_uploader = None
        self.group_uploader = None
        self.reset_button = None  # NOTE(review): not used anywhere in the visible notebook

    def create_download_link(self, file_path, label):
        """Return an HTML widget that downloads ``file_path``.

        The file content is embedded directly in the link as a base64 data URI.
        If ``file_path`` is not an existing regular file, a red error-message
        widget is returned instead.
        """
        if not os.path.isfile(file_path):
            # Show an error message if the file does not exist
            return widgets.HTML(f"""
                <span style="color: red; margin-left: 20px; font-size: 14px;">
                    File "{html.escape(str(file_path))}" not found!
                </span>
            """)

        # Read file content and encode it as base64
        with open(file_path, 'rb') as f:
            b64_content = base64.b64encode(f.read()).decode('utf-8')

        # Escape the file name and label so quotes or '<' cannot break the markup
        return widgets.HTML(f"""
            <a download="{html.escape(os.path.basename(file_path))}" 
               href="data:application/octet-stream;base64,{b64_content}" 
               style="color: #0366d6; text-decoration: none; margin-left: 20px; font-size: 14px;">
                {html.escape(str(label))}
            </a>
        """)

    def setup_data_loading_ui(self):
        """Create the upload widgets, register their observers and display the panel.

        Must be called before constructing objects that observe
        ``merged_uploader`` / ``group_uploader``, because those widgets are
        created here.
        """
        # Create file upload widgets
        self.merged_uploader = widgets.FileUpload(
            accept='.csv,.txt,.tsv,.xlsx',
            multiple=False,
            description='Upload Merged Data File',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.group_uploader = widgets.FileUpload(
            accept='.json',
            multiple=False,
            description='Upload Group Definition',
            layout=widgets.Layout(width='300px'),
            style={'description_width': 'initial'}
        )

        self.output_area = widgets.Output()

        # Create individual upload boxes with example links
        merged_box = widgets.HBox([
            self.merged_uploader,
            self.create_download_link("example_merged_dataframe.csv", "Example")
        ], layout=widgets.Layout(align_items='center'))

        group_box = widgets.HBox([
            self.group_uploader,
            self.create_download_link("example_group_definition.json", "Example")
        ], layout=widgets.Layout(align_items='center'))

        # Create upload section
        upload_widgets = widgets.VBox([
            widgets.HTML("<h3>Upload Data Files</h3>"),
            merged_box,
            widgets.HTML("<h3>Upload Group Definition</h3>"),
            group_box,
            self.output_area
        ])

        # Create status area
        self.status_area = widgets.Output()

        # Create grid layout
        grid = widgets.GridBox(
            [upload_widgets, self.status_area],
            layout=widgets.Layout(
                grid_template_columns='auto auto',
                grid_gap='20px',
                width='900px'
            )
        )

        # Register observers
        self.merged_uploader.observe(self._on_merged_upload_change, names='value')
        self.group_uploader.observe(self._on_group_upload_change, names='value')

        # Display the grid
        display(grid)

    def _on_merged_upload_change(self, change):
        """Load the first uploaded data file into ``self.merged_df`` and report the result."""
        if change['type'] != 'change' or change['name'] != 'value':
            return
        with self.output_area:
            self.output_area.clear_output()
            uploaded_files = change.get('new', ())
            if not uploaded_files:
                return
            # Only the first uploaded file is used (the widget is single-file)
            self.merged_df = self._load_merged_data(uploaded_files[0])
            if self.merged_df is not None:
                display(HTML(
                    f'<b style="color:green;">Merged data imported: '
                    f'{self.merged_df.shape[0]} rows, {self.merged_df.shape[1]} columns</b>'
                ))

    def _on_group_upload_change(self, change):
        """Parse the first uploaded JSON file into ``self.group_data`` and report the result."""
        if change['type'] != 'change' or change['name'] != 'value':
            return
        with self.output_area:
            self.output_area.clear_output()
            uploaded_files = change.get('new', ())
            if not uploaded_files:
                return
            file_data = uploaded_files[0]
            try:
                # 'utf-8-sig' also accepts files saved with a UTF-8 byte-order mark,
                # which json.loads rejects when it is left in the decoded text.
                content = bytes(file_data.content).decode('utf-8-sig')
                self.group_data = self._process_group_data(json.loads(content))
                display(HTML(
                    f'<b style="color:green;">Group definition imported successfully with {len(self.group_data)} groups.</b><br>' 
                ))
            except Exception as e:
                # Mirror the merged-data path: a failed upload must not leave the
                # previous definition silently in use.
                self.group_data = None
                display(HTML(
                    f'<b style="color:red;">Error loading group definition: {html.escape(str(e))}</b>'
                ))

    def _load_merged_data(self, file_data):
        """Parse an uploaded table (.csv, tab-separated .txt/.tsv, or .xlsx).

        Args:
            file_data: uploaded-file object exposing ``.content`` (bytes-like) and ``.name``.

        Returns:
            The DataFrame, or None after displaying an error when the format is
            unsupported, parsing fails, or a required column is missing.
        """
        try:
            content = bytes(file_data.content)
            extension = file_data.name.split('.')[-1].lower()
            file_stream = io.BytesIO(content)

            if extension == 'csv':
                df = pd.read_csv(file_stream)
            elif extension in ['txt', 'tsv']:
                df = pd.read_csv(file_stream, delimiter='\t')
            elif extension == 'xlsx':
                df = pd.read_excel(file_stream)
            else:
                raise ValueError("Unsupported file format")

            # Validate required columns (sorted so the message is deterministic)
            required_columns = ['Master Protein Accessions', 'Positions in Proteins']
            missing = set(required_columns) - set(df.columns)
            if missing:
                raise ValueError(f"Missing required columns: {', '.join(sorted(missing))}")

            return df

        except Exception as e:
            display(HTML(f'<b style="color:red;">Error loading data: {html.escape(str(e))}</b>'))
            return None

    def _process_group_data(self, json_data):
        """Validate the decoded group-definition JSON and normalise it.

        Args:
            json_data: decoded JSON; must map group ids to objects that contain
                'grouping_variable' and a list 'abundance_columns'.

        Returns:
            dict mapping each group id to {'grouping_variable', 'abundance_columns'};
            any other keys are dropped.

        Raises:
            ValueError: message prefixed with "Error processing group data:".
        """
        try:
            if not isinstance(json_data, dict):
                raise ValueError(
                    "top-level JSON value must be an object mapping group ids to definitions"
                )
            processed_data = {}
            for group_id, group_info in json_data.items():
                if not isinstance(group_info, dict):
                    raise ValueError(f"Group {group_id} must be an object")
                # Validate required fields
                if 'grouping_variable' not in group_info:
                    raise ValueError(f"Group {group_id} missing grouping_variable")
                if 'abundance_columns' not in group_info:
                    raise ValueError(f"Group {group_id} missing abundance_columns")
                abundance_columns = group_info['abundance_columns']
                # A bare string would later select a single Series and break the
                # row-wise correlation code, so require an explicit list.
                if not isinstance(abundance_columns, list):
                    raise ValueError(
                        f"Group {group_id} abundance_columns must be a list of column names"
                    )

                # Create standardized group entry
                processed_data[group_id] = {
                    'grouping_variable': group_info['grouping_variable'],
                    'abundance_columns': list(abundance_columns)
                }

            return processed_data
        except Exception as e:
            raise ValueError(f"Error processing group data: {str(e)}") from e


In [3]:
class CorrelationPlotter:
    """Interactive log10-log10 scatter plot comparing the ``Avg_<group>`` columns of two groups.

    Reads ``merged_df`` / ``group_data`` from the given data transformer and
    re-renders whenever an upload or one of the controls changes.
    """

    # Dropdown entries with special meaning: no filter, and rows whose 'function' is missing.
    _RESERVED_FUNCTION_OPTIONS = ('All', 'Unknown')

    def __init__(self, data_transformer):
        self.data_transformer = data_transformer
        self.plot_output = widgets.Output()
        self.info_output = widgets.Output()  # peptide count / active filter for the current plot
        self.export_output = widgets.Output()

        # Initialize function options
        self.function_options = self._get_function_options()

        # Create widgets for correlation analysis
        self.group1_dropdown = widgets.Dropdown(
            description='Group 1:',
            options=[],
            layout=widgets.Layout(width='400px')
        )

        self.group2_dropdown = widgets.Dropdown(
            description='Group 2:',
            options=[],
            layout=widgets.Layout(width='400px')
        )

        self.function_widget = widgets.Dropdown(
            options=self.function_options,
            description='Function:',
            value='All',
            layout=widgets.Layout(width='400px')
        )

        # Label customization widgets
        self.xlabel_widget = widgets.Text(
            description='X Label:',
            placeholder='Enter x-axis label',
            layout=widgets.Layout(width='400px')
        )

        self.ylabel_widget = widgets.Text(
            description='Y Label:',
            placeholder='Enter y-axis label',
            layout=widgets.Layout(width='400px')
        )

        self.plot_color = widgets.ColorPicker(
            concise=False,
            description='Plot Color:',
            value='#0072C6',
            layout=widgets.Layout(width='400px')
        )

        # Create layout with all widgets
        self.widget_box = widgets.VBox([
            widgets.HTML("<h3>Correlation Analysis</h3>"),
            self.group1_dropdown,
            self.group2_dropdown,
            self.function_widget,
            self.xlabel_widget,
            self.ylabel_widget,
            self.plot_color,
            self.info_output,
            self.export_output,
            self.plot_output
        ])

        # Refresh the dropdowns whenever either file is (re)uploaded
        self.data_transformer.merged_uploader.observe(self._on_data_change, names='value')
        self.data_transformer.group_uploader.observe(self._on_data_change, names='value')

        # Re-render on any control change
        for control in (self.group1_dropdown, self.group2_dropdown, self.function_widget,
                        self.xlabel_widget, self.ylabel_widget, self.plot_color):
            control.observe(self._on_widget_change, names='value')

    @staticmethod
    def _split_functions(value):
        """Split a 'function' cell on ',' and ';' into stripped, non-empty components."""
        return [part.strip() for part in str(value).replace(';', ',').split(',') if part.strip()]

    @staticmethod
    def _function_mask(functions, function_filter):
        """Return a boolean Series that is True where ``function_filter`` is one of the row's components.

        Matching is case-insensitive and exact per component. It uses the same
        split as the dropdown options and is literal, so names containing regex
        characters such as '(' or '+' match as written. Non-string cells never match.
        """
        target = str(function_filter).strip().lower()

        def matches(value):
            if not isinstance(value, str):
                return False
            return target in {part.lower() for part in CorrelationPlotter._split_functions(value)}

        return functions.map(matches).astype(bool)

    def _get_function_options(self):
        """Return the function-filter dropdown options.

        The list is 'All' and 'Unknown', followed by every function component
        (from the optional 'function' column) that occurs more than once,
        sorted alphabetically.
        """
        options = list(self._RESERVED_FUNCTION_OPTIONS)
        merged_df = self.data_transformer.merged_df
        if merged_df is None or 'function' not in merged_df.columns:
            return options

        function_counts = {}
        for value in merged_df['function']:
            if not isinstance(value, str):
                continue  # NaN / non-text cells carry no function names
            for component in self._split_functions(value):
                function_counts[component] = function_counts.get(component, 0) + 1

        # Reserved labels are excluded so the dropdown never holds duplicate entries
        valid_functions = sorted(
            name for name, count in function_counts.items()
            if count > 1 and name not in self._RESERVED_FUNCTION_OPTIONS
        )
        return options + valid_functions

    def create_correlation_plot(self, group1, group2, function_filter, plot_color):
        """Build a log10-log10 scatter of ``Avg_<group1>`` against ``Avg_<group2>``.

        Args:
            group1, group2: grouping-variable names; their ``Avg_<name>`` columns are plotted.
            function_filter: 'All' (no filter), 'Unknown' (rows with no 'function'),
                or a function name matched against the comma/semicolon-separated
                components of 'function' (case-insensitive, exact per component).
            plot_color: marker and trendline colour.

        Returns:
            A plotly Figure. Returns None when data is not loaded, either ``Avg_``
            column is absent, or no row has positive values in both columns. In
            those cases the info area is cleared so it never describes an earlier plot.
        """
        merged_df = self.data_transformer.merged_df
        if merged_df is None or self.data_transformer.group_data is None:
            self.info_output.clear_output()
            return None

        col1 = f"Avg_{group1}"
        col2 = f"Avg_{group2}"
        if col1 not in merged_df.columns or col2 not in merged_df.columns:
            self.info_output.clear_output()
            return None

        # 'function' is optional in the merged table; treat a missing column as all-unknown.
        if 'function' in merged_df.columns:
            functions = merged_df['function']
        else:
            functions = pd.Series(np.nan, index=merged_df.index, dtype=object)

        if function_filter == 'Unknown':
            keep = functions.isna().to_numpy()
        elif function_filter != 'All':
            keep = self._function_mask(functions, function_filter).to_numpy()
        else:
            keep = np.ones(len(merged_df), dtype=bool)

        # log10 is only defined for positive abundances; NaN/NA count as not positive
        keep &= ((merged_df[col1] > 0) & (merged_df[col2] > 0)).to_numpy(dtype=bool, na_value=False)
        filtered_df = merged_df.loc[keep]
        if filtered_df.empty:
            self.info_output.clear_output()
            return None

        # Update info display
        with self.info_output:
            self.info_output.clear_output(wait=True)
            print(f"Displaying {len(filtered_df)} peptides")
            if function_filter != 'All':
                print(f"Function filter: {function_filter}")

        raw_x = filtered_df[col1]
        raw_y = filtered_df[col2]
        log_x = np.log10(raw_x)
        log_y = np.log10(raw_y)

        # A trendline needs at least two distinct x values; Pearson's r also needs y to vary.
        can_fit = log_x.nunique() > 1
        if can_fit and log_y.nunique() > 1:
            corr, _ = pearsonr(log_x, log_y)
        else:
            corr = float('nan')

        # 'unique ID' is not a required upload column; fall back to the row index.
        if 'unique ID' in filtered_df.columns:
            peptide_ids = filtered_df['unique ID'].tolist()
        else:
            peptide_ids = filtered_df.index.tolist()
        customdata = list(zip(
            peptide_ids,
            functions[keep].fillna('Unknown').tolist(),
            raw_x.tolist(),
            raw_y.tolist(),
        ))

        fig = go.Figure()

        # Hover shows the raw abundance (x/y are log10 values) alongside its log10
        fig.add_trace(go.Scatter(
            x=log_x,
            y=log_y,
            mode='markers',
            marker=dict(color=plot_color),
            hovertemplate=(
                '<b>Peptide ID:</b> %{customdata[0]}<br>' +
                '<b>Function:</b> %{customdata[1]}<br>' +
                f'<b>{group1}:</b> %{{customdata[2]:.2e}} (log10 %{{x:.2f}})<br>' +
                f'<b>{group2}:</b> %{{customdata[3]:.2e}} (log10 %{{y:.2f}})<br>' +
                '<extra></extra>'
            ),
            customdata=customdata
        ))

        if can_fit:
            z = np.polyfit(log_x, log_y, 1)
            x_range = np.linspace(log_x.min(), log_x.max(), 100)
            fig.add_trace(go.Scatter(
                x=x_range,
                y=np.poly1d(z)(x_range),
                mode='lines',
                line=dict(color=plot_color, dash='dash'),
                showlegend=False,
                hovertemplate='<extra></extra>'
            ))

        # Custom labels, or defaults naming the plotted groups
        xlabel = self.xlabel_widget.value or f'Log10({group1})'
        ylabel = self.ylabel_widget.value or f'Log10({group2})'

        fig.update_layout(
            title=dict(
                text=f'Correlation Plot (r = {corr:.3f})',
                x=0.5,
                xanchor='center'
            ),
            xaxis_title=xlabel,
            yaxis_title=ylabel,
            width=800,
            height=800,
            template='simple_white',
            showlegend=False,
            hoverlabel=dict(
                bgcolor="white",
                font_size=14,
                font_family="Arial"
            )
        )

        # Make aspect ratio equal
        fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1))

        return fig

    def _on_widget_change(self, change):
        """Re-render when any control changes, once both groups are selected."""
        if all([self.group1_dropdown.value, self.group2_dropdown.value]):
            self._update_plot()

    def _on_data_change(self, change):
        """Refresh the group and function dropdowns once both uploads are loaded."""
        merged_df = self.data_transformer.merged_df
        group_data = self.data_transformer.group_data
        if merged_df is None or group_data is None:
            return

        # Unique grouping variables, in definition order
        group_vars = list(dict.fromkeys(
            info['grouping_variable'] for info in group_data.values()
        ))
        self.group1_dropdown.options = group_vars
        self.group2_dropdown.options = group_vars
        # Assigning options selects the first entry in both dropdowns; preselect a
        # different second group so the initial view is a valid comparison.
        if len(group_vars) > 1:
            self.group2_dropdown.value = group_vars[1]

        self.function_widget.options = self._get_function_options()

    def _update_plot(self):
        """Re-render the plot from the current widget values."""
        with self.plot_output:
            self.plot_output.clear_output(wait=True)

            if self.group1_dropdown.value == self.group2_dropdown.value:
                # Drop the previous plot's peptide count so it does not look current
                self.info_output.clear_output()
                display(HTML('<b style="color:red;">Please select different groups for comparison.</b>'))
                return

            fig = self.create_correlation_plot(
                self.group1_dropdown.value,
                self.group2_dropdown.value,
                self.function_widget.value,
                self.plot_color.value
            )

            if fig is not None:
                fig.show(config={
                    'displayModeBar': True,
                    'scrollZoom': True,
                    'modeBarButtonsToRemove': ['select2d', 'lasso2d']
                })
            else:
                display(HTML('<b style="color:red;">No data available for the selected combination.</b>'))

    def display(self):
        """Display the correlation analysis interface"""
        display(self.widget_box)



    

In [4]:
class CorrelationExportHandler:
    """Builds downloadable exports of cross-group (CSV) and within-group (Excel) correlations."""

    # Characters Excel forbids in worksheet titles, and its title-length limit.
    _INVALID_SHEET_CHARS = re.compile(r'[\[\]:*?/\\]')
    _MAX_SHEET_NAME_LENGTH = 31
    _SUMMARY_SHEET = 'Summary'

    _STYLE = """
            <style>
            .download-link {
                background-color: #4CAF50;
                border: none;
                color: white;
                padding: 10px 20px;
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-size: 14px;
                margin: 4px 2px;
                cursor: pointer;
                border-radius: 4px;
            }
            .download-link:hover {
                background-color: #45a049;
            }
            .export-section {
                margin-bottom: 20px;
                padding: 15px;
                border-radius: 5px;
                background-color: #f8f9fa;
            }
            .export-description {
                color: #666;
                margin: 5px 0 15px 0;
                font-style: italic;
            }
            </style>
        """

    def __init__(self, data_transformer):
        self.data_transformer = data_transformer
        self.export_output = widgets.Output()

    def display(self):
        """Display the export section"""
        display(HTML("<h2>Export:</h2>"))

        display(self.export_output)

    def generate_export_links(self):
        """Regenerate the export links, or show an error message if building them fails.

        Does nothing until both the merged data and the group definition are loaded.
        """
        merged_df = self.data_transformer.merged_df
        group_data = self.data_transformer.group_data
        if merged_df is None or group_data is None:
            return

        with self.export_output:
            self.export_output.clear_output(wait=True)
            try:
                html_content = self._build_export_html(merged_df, group_data)
            except Exception as e:
                # Called from upload observers: report the failure here rather than
                # raising out of the observer callback.
                html_content = (
                    f'<b style="color:red;">Error generating exports: {html.escape(str(e))}</b>'
                )
            display(HTML(html_content))

    def _build_export_html(self, df, group_data):
        """Compute all correlations and return the export section's HTML."""
        cross_group_correlations = self._calculate_cross_group_correlations(df, group_data)
        within_group_tables, skipped_groups = self._calculate_within_group_correlations(df, group_data)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        # Cross-group correlations CSV
        csv_filename = f"cross_group_correlations_{timestamp}.csv"
        csv_link = self._download_link(cross_group_correlations, csv_filename)

        # Within-group correlations Excel (only when at least one group produced results)
        if within_group_tables:
            excel_filename = f"within_group_correlations_{timestamp}.xlsx"
            within_group_content = self._download_link(
                self._build_within_group_workbook(within_group_tables),
                excel_filename,
                'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
            )
        else:
            within_group_content = (
                '<div class="export-description">'
                'No group has at least two rows with positive values in all of its abundance columns.'
                '</div>'
            )
        if skipped_groups:
            names = ', '.join(html.escape(str(group)) for group in skipped_groups)
            within_group_content += (
                f'<div style="color:#b00;">Skipped (abundance columns not found in data): {names}</div>'
            )

        return f"""
            {self._STYLE}
            <div class="export-section">
                <h3>Cross-Group Correlations</h3>
                <div class="export-description">
                    Download correlations between different sample groups (CSV format)
                </div>
                {csv_link}
            </div>
            
            <div class="export-section">
                <h3>Within-Group Correlations</h3>
                <div class="export-description">
                    Download detailed correlation analysis for each group with summary statistics (Excel format)
                </div>
                {within_group_content}
            </div>
            """

    @staticmethod
    def _download_link(content, filename, filetype='text/csv'):
        """Return an <a> tag that downloads ``content`` as a base64 data URI.

        A DataFrame is serialised to CSV, with its index included only when
        ``filetype`` is not 'text/csv'. A str is UTF-8 encoded. Bytes are used as-is.
        """
        if isinstance(content, pd.DataFrame):
            content = content.to_csv(index=(filetype != 'text/csv'))
        if isinstance(content, str):
            content = content.encode()
        b64 = base64.b64encode(content).decode()
        return f"""
                <a download="{filename}" href="data:{filetype};base64,{b64}" class="download-link" 
                   title="Click to download">
                    Download {filename}
                </a>
            """

    @staticmethod
    def _calculate_cross_group_correlations(df, group_data):
        """Compute the Pearson r of log10 ``Avg_<group>`` columns for every pair of groups.

        Only rows positive in both columns are used, and a pair needs more than
        one such row. Returns a DataFrame with columns 'Group 1', 'Group 2',
        'Correlation' (r rounded to 3 decimals).
        """
        avg_columns = {}
        for group_info in group_data.values():
            column = f"Avg_{group_info['grouping_variable']}"
            if column in df.columns:
                avg_columns[group_info['grouping_variable']] = column

        rows = []
        for (group1, col1), (group2, col2) in combinations(avg_columns.items(), 2):
            mask = (df[col1] > 0) & (df[col2] > 0)
            if mask.sum() > 1:
                correlation = np.log10(df.loc[mask, col1]).corr(np.log10(df.loc[mask, col2]))
                rows.append((group1, group2, round(correlation, 3)))
        return pd.DataFrame(rows, columns=['Group 1', 'Group 2', 'Correlation'])

    @staticmethod
    def _calculate_within_group_correlations(df, group_data):
        """Compute pairwise log10 Pearson r between each group's abundance columns.

        Only rows positive in all of the group's abundance columns are used, and a
        group needs more than one such row. Pairs are ordered (0,1), (0,2), (1,2),
        (0,3), and so on.

        Returns:
            (tables, skipped): ``tables`` maps each grouping variable to a DataFrame
            with 'Pair' ("<col_j> vs <col_i>") and 'Correlation' (rounded to 3 decimals).
            ``skipped`` lists grouping variables whose abundance columns are not all
            present in ``df``.
        """
        tables = {}
        skipped = []
        for group_info in group_data.values():
            grouping_variable = group_info['grouping_variable']
            abundance_columns = group_info['abundance_columns']

            if any(column not in df.columns for column in abundance_columns):
                skipped.append(grouping_variable)
                continue

            data = df[abundance_columns]
            data = data[data.gt(0).all(axis=1)]
            if len(data) <= 1:
                continue

            matrix = np.log10(data).corr(method='pearson')
            pairs, values = [], []
            for i in range(len(abundance_columns)):
                for j in range(i):
                    pairs.append(f"{abundance_columns[j]} vs {abundance_columns[i]}")
                    values.append(round(matrix.iloc[i, j], 3))
            tables[grouping_variable] = pd.DataFrame({'Pair': pairs, 'Correlation': values})
        return tables, skipped

    @classmethod
    def _sheet_names(cls, group_names):
        """Map each group name to a unique, Excel-valid worksheet title.

        Forbidden characters are replaced by '_', leading/trailing apostrophes are
        stripped, and titles are cut to 31 characters. Collisions, including with
        the 'Summary' sheet, are resolved case-insensitively with a '_<n>' suffix.
        """
        used = {cls._SUMMARY_SHEET.lower()}
        names = {}
        for group in group_names:
            base = cls._INVALID_SHEET_CHARS.sub('_', str(group)).strip("'") or 'Group'
            base = base[:cls._MAX_SHEET_NAME_LENGTH]
            candidate = base
            counter = 1
            while candidate.lower() in used:
                suffix = f"_{counter}"
                candidate = base[:cls._MAX_SHEET_NAME_LENGTH - len(suffix)] + suffix
                counter += 1
            used.add(candidate.lower())
            names[group] = candidate
        return names

    @classmethod
    def _build_within_group_workbook(cls, within_group_tables):
        """Return .xlsx bytes containing a 'Summary' sheet plus one sheet per group.

        The Summary sheet lists each group's correlations as a column, followed by
        Min/Max/Average rows (rounded to 3 decimals). ``within_group_tables`` must
        be non-empty.
        """
        correlations = pd.concat(
            {group: table['Correlation'] for group, table in within_group_tables.items()},
            axis=1
        )
        summary_df = pd.DataFrame({
            'Min': correlations.min().round(3),
            'Max': correlations.max().round(3),
            'Average': correlations.mean().round(3)
        }).T

        buffer = io.BytesIO()
        with pd.ExcelWriter(buffer, engine='openpyxl') as writer:
            pd.concat([correlations, summary_df], axis=0).to_excel(
                writer, sheet_name=cls._SUMMARY_SHEET
            )
            for group, sheet_name in cls._sheet_names(within_group_tables).items():
                within_group_tables[group].to_excel(writer, sheet_name=sheet_name, index=False)
            cls._tidy_worksheets(writer)
        return buffer.getvalue()

    @staticmethod
    def _tidy_worksheets(writer):
        """Clear cell borders and size each column to its longest rendered value."""
        for worksheet in writer.sheets.values():
            for row in worksheet.iter_rows():
                for cell in row:
                    cell.border = None
            for column_cells in worksheet.columns:
                max_length = max(
                    (len(str(cell.value)) for cell in column_cells if cell.value is not None),
                    default=0
                )
                worksheet.column_dimensions[column_cells[0].column_letter].width = max_length + 2

In [5]:
# Build the UI: the upload panel comes first because it creates the uploaders;
# the correlation plot and the export panel both observe those uploaders.
data_transformer = DataTransformation()
data_transformer.setup_data_loading_ui()

correlation_plotter = CorrelationPlotter(data_transformer)
correlation_plotter.display()

export_handler = CorrelationExportHandler(data_transformer)
export_handler.display()


def _refresh_export_links(change):
    """Rebuild the export links after an uploader's value changes (``change`` is unused)."""
    export_handler.generate_export_links()


# Registered after the loaders' own observers (presumably notified in registration
# order), so the freshly uploaded data is parsed before the exports are rebuilt.
data_transformer.merged_uploader.observe(_refresh_export_links, names='value')
data_transformer.group_uploader.observe(_refresh_export_links, names='value')

GridBox(children=(VBox(children=(HTML(value='<h3>Upload Data Files</h3>'), HBox(children=(FileUpload(value=(),…

VBox(children=(HTML(value='<h3>Correlation Analysis</h3>'), Dropdown(description='Group 1:', layout=Layout(wid…

Output()