In [1]:
# Cell 1: Import required libraries
import csv, json, re, os, sys, math, base64, io
import html
from datetime import datetime
from itertools import combinations

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from IPython.display import display, HTML
from scipy.stats import pearsonr

In [2]:
# Cell 2: Define the DataManager class and create the shared data_manager instance
class DataManager:
    """Container for the two uploaded inputs shared by the notebook callbacks.

    ``merged_df`` and ``group_data`` both start as ``None`` and are replaced
    wholesale by the corresponding setter; no copying or validation happens here.
    """

    def __init__(self):
        self.merged_df, self.group_data = None, None

    def set_merged_df(self, df):
        """Store ``df`` as the current merged table."""
        self.merged_df = df

    def set_group_data(self, data):
        """Store ``data`` as the current group definition."""
        self.group_data = data

    def has_merged_data(self):
        """Return True when a merged table is stored and it is not empty."""
        frame = self.merged_df
        if frame is None:
            return False
        return not frame.empty

    def has_group_data(self):
        """Return True when a group definition is stored and has at least one entry."""
        groups = self.group_data
        if groups is None:
            return False
        return len(groups) > 0

    def has_all_data(self):
        """Return True only when both the merged table and the group definition are usable."""
        if not self.has_merged_data():
            return False
        return self.has_group_data()

# Module-level DataManager instance; the upload callbacks write into it and the
# plotting/export functions read from it.
data_manager = DataManager()

In [3]:
def process_group_data(json_data):
    """
    Fill in defaults and check the parsed group definition.

    The mapping is modified in place and the same object is returned:
    - group ``'2'`` receives ``grouping_variable = 'New_Formula'`` when it has none;
    - every group must then contain ``grouping_variable`` and ``abundance_columns``.

    Raises:
        ValueError: for a missing field, or wrapping any other exception raised
            while inspecting ``json_data``.
    """
    required_fields = ('grouping_variable', 'abundance_columns')
    try:
        # Default applies to group '2' only
        if '2' in json_data:
            if 'grouping_variable' not in json_data['2']:
                json_data['2']['grouping_variable'] = 'New_Formula'

        # Report the first missing field of the first offending group
        for group_id, group_info in json_data.items():
            absent = next(
                (field for field in required_fields if field not in group_info),
                None,
            )
            if absent is not None:
                raise ValueError(f"Group {group_id} missing {absent}")

        return json_data
    except Exception as e:
        raise ValueError(f"Error processing group data: {str(e)}")

def load_data_files():
    """
    Create the upload widgets for the merged data file and the group definition.

    Two ``FileUpload`` widgets and one shared ``Output`` widget are created.
    Each uploader gets an observer on its ``value`` trait that parses the first
    uploaded file, stores the result in the module-level ``data_manager`` and
    then calls ``update_visualizations()``.

    Returns:
        tuple: ``(merged_uploader, group_uploader, output_area)``.

    Note:
        ``update_visualizations()`` is invoked while ``output_area`` is the
        active output context, so anything it displays directly ends up in
        ``output_area`` and is cleared by the next upload.
    """
    # Create file upload widgets
    merged_uploader = widgets.FileUpload(
        accept='.csv,.txt,.tsv,.xlsx',
        multiple=False,
        description='Upload Merged Data File',
        layout=widgets.Layout(width='300px'),
        style={'description_width': 'initial'}
    )

    group_uploader = widgets.FileUpload(
        accept='.json',
        multiple=False,
        description='Upload Group Definition',
        layout=widgets.Layout(width='300px'),
        style={'description_width': 'initial'}
    )

    # Status messages from both uploaders share this area
    output_area = widgets.Output()

    def on_merged_upload_change(change):
        """Load the uploaded merged table and refresh the visualizations."""
        if change['type'] == 'change' and change['name'] == 'value':
            with output_area:
                output_area.clear_output()
                uploaded_files = change['new']
                if len(uploaded_files) > 0:
                    file_data = uploaded_files[0]
                    # load_merged_data reports its own errors and returns None on failure
                    merged_df = load_merged_data(file_data)
                    if merged_df is not None:
                        data_manager.set_merged_df(merged_df)
                        display(HTML(f'<b style="color:green;">Merged data imported with {merged_df.shape[0]} rows and {merged_df.shape[1]} columns.</b>'))
                        update_visualizations()

    def on_group_upload_change(change):
        """Parse the uploaded JSON group definition and refresh the visualizations."""
        if change['type'] == 'change' and change['name'] == 'value':
            with output_area:
                output_area.clear_output()
                uploaded_files = change['new']
                if len(uploaded_files) > 0:
                    file_data = uploaded_files[0]
                    try:
                        # Convert memoryview to bytes then to string
                        content = bytes(file_data.content).decode('utf-8')
                        group_data_raw = json.loads(content)

                        # Process and validate group data
                        processed_group_data = process_group_data(group_data_raw)
                        data_manager.set_group_data(processed_group_data)

                        display(HTML(f'<b style="color:green;">Group definition file imported successfully with {len(processed_group_data)} groups.</b>'))
                        print("Loaded groups:", [g['grouping_variable'] for g in processed_group_data.values()])
                        update_visualizations()
                    except Exception as e:
                        # The message can echo text taken from the uploaded file
                        # (e.g. group ids), so escape it before embedding in HTML.
                        display(HTML(f'<b style="color:red;">Error loading group definition file: {html.escape(str(e))}</b>'))
                        print(f"Detailed error: {str(e)}")
                        print(f"Type of content: {type(file_data.content)}")

    merged_uploader.observe(on_merged_upload_change, names='value')
    group_uploader.observe(on_group_upload_change, names='value')

    return merged_uploader, group_uploader, output_area

def load_merged_data(file_obj):
    """
    Load and validate the merged data file.

    Args:
        file_obj: uploaded-file record exposing ``name`` (file name) and
            ``content`` (bytes-like payload).

    Returns:
        pandas.DataFrame | None: the parsed table, or ``None`` when the file
        cannot be read, has an unsupported extension, or lacks one of the
        required columns. In the failure case a red HTML error is displayed
        instead of raising.
    """
    try:
        content = bytes(file_obj.content)  # Convert memoryview to bytes
        filename = file_obj.name
        extension = filename.split('.')[-1].lower()

        file_stream = io.BytesIO(content)

        # The parser is chosen from the file extension alone
        if extension == 'csv':
            df = pd.read_csv(file_stream)
        elif extension in ['txt', 'tsv']:
            df = pd.read_csv(file_stream, delimiter='\t')
        elif extension == 'xlsx':
            df = pd.read_excel(file_stream)
        else:
            raise ValueError("Unsupported file format.")

        # Validate required columns
        required_columns = ['Master Protein Accessions', 'Positions in Proteins']
        # Keep the declared order so the message is deterministic
        missing = [column for column in required_columns if column not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {', '.join(missing)}")

        return df

    except Exception as e:
        # Parser errors can echo file content; escape before embedding in HTML
        display(HTML(f'<b style="color:red;">Error loading data: {html.escape(str(e))}</b>'))
        return None
        
def setup_correlation_exports():
    """
    Setup correlation analysis exports with styled download buttons.

    Reads the merged table and group definition from the module-level
    ``data_manager`` and displays an HTML fragment with two download links:

    * a CSV with the Pearson correlation of log10 ``Avg_<group>`` columns for
      every pair of groups (rows where either value is not positive are dropped);
    * an Excel workbook with, per group, the pairwise Pearson correlations of
      its log10 abundance columns (only rows where all columns are positive),
      plus a ``Summary`` sheet with min / max / mean per group.

    If either input is missing a hint is displayed and nothing else happens.
    If no group yields within-group correlations, the Excel link is replaced by
    a short notice instead of failing.
    """
    if not data_manager.has_all_data():
        display(HTML(f'<b style="color:blue;">Please load both data files to enable export functionality.</b>'))
        return

    # Bind the inputs explicitly: the helpers below close over these names.
    df = data_manager.merged_df
    group_data = data_manager.group_data

    # CSS for the classes used in the generated markup
    style = """
    <style>
        .export-section {
            margin: 12px 0;
            padding: 12px 16px;
            border: 1px solid #d0d7de;
            border-radius: 6px;
        }
        .export-section h3 { margin: 0 0 6px 0; }
        .export-description { margin-bottom: 8px; color: #555555; }
        .download-link {
            display: inline-block;
            padding: 6px 14px;
            border-radius: 4px;
            background-color: #0072C6;
            color: #ffffff !important;
            text-decoration: none;
        }
        .download-link:hover { background-color: #005a9e; }
    </style>
    """

    def generate_download_link(content, filename, filetype='text/csv'):
        """Generate a base64 data-URI download link for a DataFrame, str or bytes."""
        if isinstance(content, pd.DataFrame):
            if filetype == 'text/csv':
                content = content.to_csv(index=False)  # Remove index only for CSV
            else:
                content = content.to_csv(index=True)  # Keep index for other formats
        if isinstance(content, str):
            content = content.encode()
        b64 = base64.b64encode(content).decode()
        return f"""
            <a download="{filename}" href="data:{filetype};base64,{b64}" class="download-link" 
               title="Click to download">
                Download {filename}
            </a>
        """

    def make_sheet_name(name, used_names):
        """Return a unique sheet title (max 31 chars, no []:*?/\\) derived from name."""
        cleaned = re.sub(r'[\[\]:*?/\\]', '_', str(name)) or 'Group'
        candidate = cleaned[:31]
        counter = 1
        # Compare case-insensitively so titles differing only in case do not clash
        while candidate.lower() in used_names:
            counter += 1
            tail = f"_{counter}"
            candidate = cleaned[:31 - len(tail)] + tail
        used_names.add(candidate.lower())
        return candidate

    def calculate_all_correlations():
        """Calculate correlations between all group pairs"""
        correlation_results = []
        # Groups without an Avg_<name> column in the table are skipped
        Avg_columns = {
            group_info['grouping_variable']: f"Avg_{group_info['grouping_variable']}"
            for group_info in group_data.values()
            if f"Avg_{group_info['grouping_variable']}" in df.columns
        }

        for (group1, col1), (group2, col2) in combinations(Avg_columns.items(), 2):
            # Filter for positive values and calculate correlation
            mask = (df[col1] > 0) & (df[col2] > 0)
            if mask.sum() > 1:
                log_col1 = np.log10(df.loc[mask, col1])
                log_col2 = np.log10(df.loc[mask, col2])
                correlation = round(log_col1.corr(log_col2), 3)
                correlation_results.append((group1, group2, correlation))

        return pd.DataFrame(correlation_results, columns=['Group 1', 'Group 2', 'Correlation'])

    def calculate_group_correlations():
        """Return {grouping_variable: DataFrame(Pair, Correlation)} for each usable group."""
        correlation_tables = {}

        for group_info in group_data.values():
            grouping_variable = group_info['grouping_variable']
            # Columns absent from the table are ignored rather than raising KeyError
            columns = [column for column in group_info['abundance_columns'] if column in df.columns]
            if len(columns) < 2:
                continue  # no pair to correlate

            data = df[columns]
            data = data[data.gt(0).all(axis=1)]  # log10 needs strictly positive values

            if len(data) > 1:
                correlation_matrix = np.log10(data).corr(method='pearson')

                # Walk the strict lower triangle: one entry per unordered pair
                pairs = []
                values = []
                for i in range(len(columns)):
                    for j in range(i):
                        pairs.append(f"{columns[j]} vs {columns[i]}")
                        values.append(round(correlation_matrix.iloc[i, j], 3))

                correlation_tables[grouping_variable] = pd.DataFrame({
                    'Pair': pairs,
                    'Correlation': values
                })

        return correlation_tables

    # Calculate correlations
    cross_group_correlations = calculate_all_correlations()
    within_group_correlations = calculate_group_correlations()

    # Generate download sections
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Cross-group correlations CSV (without index)
    csv_filename = f"cross_group_correlations_{timestamp}.csv"
    csv_link = generate_download_link(cross_group_correlations, csv_filename)

    if within_group_correlations:
        # Create Excel for within-group correlations
        buffer = io.BytesIO()
        with pd.ExcelWriter(buffer, engine='openpyxl') as writer:
            # Combine correlations into DataFrame with numeric index
            combined_correlation_df = pd.concat(
                {name: table['Correlation'] for name, table in within_group_correlations.items()},
                axis=1
            )

            # Create summary DataFrame with statistics (one row per statistic)
            summary_df = pd.DataFrame({
                'Min': combined_correlation_df.min().round(3),
                'Max': combined_correlation_df.max().round(3),
                'Average': combined_correlation_df.mean().round(3)
            }).T

            # Combine correlations and summary, keeping numeric index
            combined_correlation_with_summary = pd.concat([combined_correlation_df, summary_df], axis=0)

            # Write summary sheet with numeric index
            combined_correlation_with_summary.to_excel(writer, sheet_name='Summary')

            # Write individual group sheets; reserve 'Summary' so no group overwrites it
            used_sheet_names = {'summary'}
            for grouping_variable, table in within_group_correlations.items():
                sheet_name = make_sheet_name(grouping_variable, used_sheet_names)
                table.to_excel(writer, sheet_name=sheet_name, index=False)

            # Remove borders and adjust column widths
            for sheet_name in writer.sheets:
                worksheet = writer.sheets[sheet_name]
                for row in worksheet.iter_rows():
                    for cell in row:
                        cell.border = None
                for column_cells in worksheet.columns:
                    max_length = max(len(str(cell.value)) if cell.value is not None else 0 
                                   for cell in column_cells)
                    worksheet.column_dimensions[column_cells[0].column_letter].width = max_length + 2

        # Within-group correlations Excel (with index)
        excel_filename = f"within_group_correlations_{timestamp}.xlsx"
        excel_link = generate_download_link(
            buffer.getvalue(), 
            excel_filename, 
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
        )
    else:
        # Nothing to write: an empty concat / empty workbook would raise
        excel_link = (
            '<i>No within-group correlations could be calculated: each group needs at least '
            'two abundance columns present in the data and more than one row with all-positive values.</i>'
        )

    html_content = f"""
    {style}
    <div class="export-section">
        <h3>Cross-Group Correlations</h3>
        <div class="export-description">
            Download correlations between different sample groups (CSV format)
        </div>
        {csv_link}
    </div>
    
    <div class="export-section">
        <h3>Within-Group Correlations</h3>
        <div class="export-description">
            Download detailed correlation analysis for each group with summary statistics (Excel format)
        </div>
        {excel_link}
    </div>
    """

    display(HTML(html_content))


def plot_interactive_correlation():
    """
    Creates an interactive correlation plot with simplified function filtering.

    Reads the merged table and group definition from the module-level
    ``data_manager`` and displays:

    * dropdowns for the two groups to compare and for a function filter,
      a colour picker and two axis-label text boxes;
    * a Plotly scatter of ``log10(Avg_<group 1>)`` against
      ``log10(Avg_<group 2>)`` for rows where both values are positive, with a
      dashed least-squares trendline and the Pearson r in the title.

    The plot is redrawn whenever any widget value changes. If either input is
    missing a hint is displayed and nothing else happens.

    The ``function`` and ``unique ID`` columns are optional: without
    ``function`` every row is treated as "Unknown"; without ``unique ID`` the
    row index is shown in the hover text.
    """
    if not data_manager.has_all_data():
        display(HTML(f'<b style="color:blue;">Please load both data files to begin visualization.</b>'))
        return

    df = data_manager.merged_df.copy()
    group_data = data_manager.group_data

    group_names = [group_info['grouping_variable'] for group_info in group_data.values()]
    has_function = 'function' in df.columns
    has_peptide_id = 'unique ID' in df.columns

    # Initialize function options
    function_options = ['All', 'Unknown']
    if has_function:
        function_counts = {}

        # Count occurrences of each function component (',' or ';' separated)
        for func in df['function'].dropna():
            for component in str(func).replace(';', ',').split(','):
                component = component.strip()
                if component:  # ignore empty pieces from stray separators
                    function_counts[component] = function_counts.get(component, 0) + 1

        # Filter for functions with more than 1 occurrence; avoid duplicating
        # the built-in 'All' / 'Unknown' entries
        valid_functions = [
            name for name, count in function_counts.items()
            if count > 1 and name not in function_options
        ]
        function_options.extend(sorted(valid_functions))

    # Create plotting widgets
    group1_widget = widgets.Dropdown(
        options=group_names,
        description='Group 1:',
        layout=widgets.Layout(width='400px')
    )

    group2_widget = widgets.Dropdown(
        options=group_names,
        # Start on the second group so the first render is a real comparison
        value=group_names[1] if len(group_names) > 1 else group_names[0],
        description='Group 2:',
        layout=widgets.Layout(width='400px')
    )

    function_widget = widgets.Dropdown(
        options=function_options,
        description='Function:',
        value='All',
        layout=widgets.Layout(width='400px')
    )

    color_widget = widgets.ColorPicker(
        concise=False,
        description='Plot Color:',
        value='#0072C6',
        layout=widgets.Layout(width='400px')
    )

    xlabel_widget = widgets.Text(
        description='X Label:',
        placeholder='Enter x-axis label',
        layout=widgets.Layout(width='400px')
    )

    ylabel_widget = widgets.Text(
        description='Y Label:',
        placeholder='Enter y-axis label',
        layout=widgets.Layout(width='400px')
    )

    # Output widgets
    info_output = widgets.Output()
    plot_output = widgets.Output()

    def decade_ticks(log_values):
        """Return (tickvals, ticktext) at whole powers of ten spanning log_values."""
        first = int(np.floor(log_values.min()))
        last = int(np.ceil(log_values.max()))
        tickvals = list(range(first, last + 1))
        return tickvals, [f'10^{value}' for value in tickvals]

    def create_plot(change=None):
        """Redraw the scatter plot from the current widget values."""
        group1 = group1_widget.value
        group2 = group2_widget.value
        function_filter = function_widget.value
        color = color_widget.value
        xlabel = xlabel_widget.value
        ylabel = ylabel_widget.value

        # Drop the previous summary so it cannot linger next to an error message
        info_output.clear_output()

        with plot_output:
            plot_output.clear_output(wait=True)

            if group1 == group2:
                display(HTML(f'<b style="color:red;">Error: Please select different groups for comparison.</b>'))
                return

            col1 = f"Avg_{group1}"
            col2 = f"Avg_{group2}"

            if col1 not in df.columns or col2 not in df.columns:
                display(HTML(f'<b style="color:red;">Error: Skipping plot for {group1} vs {group2} as one or both columns are missing.</b>'))
                return

            # Filter data based on selected function
            working_df = df.copy()
            if function_filter == 'Unknown':
                # Without a function column every row is unannotated: keep all
                if has_function:
                    working_df = working_df[working_df['function'].isna()]
            elif function_filter != 'All':
                # Literal, case-insensitive substring match; regex=False keeps
                # characters such as '(' or '+' in function names from being
                # interpreted as a pattern
                working_df = working_df[
                    working_df['function'].fillna('').astype(str).str.contains(
                        function_filter,
                        case=False,
                        na=False,
                        regex=False
                    )
                ]

            # Filter for positive values
            filtered_df = working_df[(working_df[col1] > 0) & (working_df[col2] > 0)].copy()

            if len(filtered_df) == 0:
                display(HTML(f'<b style="color:red;">No data points available for the selected combination.</b>'))
                return

            # Update info display
            with info_output:
                info_output.clear_output()
                print(f"Displaying {len(filtered_df)} peptides")
                if function_filter != 'All':
                    print(f"Function filter: {function_filter}")

            # Calculate log values
            filtered_df['log_x'] = np.log10(filtered_df[col1])
            filtered_df['log_y'] = np.log10(filtered_df[col2])

            # Calculate correlation (undefined for a single point)
            if len(filtered_df) > 1:
                corr, _ = pearsonr(filtered_df['log_x'], filtered_df['log_y'])
            else:
                corr = float('nan')

            # Hover columns; object dtype keeps the numeric columns numeric when
            # everything is stacked into a single customdata array
            if has_peptide_id:
                peptide_ids = filtered_df['unique ID']
            else:
                peptide_ids = np.asarray(filtered_df.index.astype(str), dtype=object)
            if has_function:
                function_labels = filtered_df['function'].fillna('Unknown')
            else:
                function_labels = np.full(len(filtered_df), 'Unknown', dtype=object)

            # Create scatter plot
            fig = go.Figure()

            # Add scatter points
            fig.add_trace(go.Scatter(
                x=filtered_df['log_x'],
                y=filtered_df['log_y'],
                mode='markers',
                name='Data Points',
                marker=dict(
                    color=color,
                    size=8,
                    line=dict(
                        color=color,
                        width=1
                    )
                ),
                hovertemplate=
                f'<b>Peptide ID:</b> %{{customdata[2]}}<br>' +
                f'<b>Function:</b> %{{customdata[3]}}<br>' +
                f'<b>{group1}:</b> %{{customdata[0]:.2e}}<br>' +
                f'<b>{group2}:</b> %{{customdata[1]:.2e}}<br>' +
                '<extra></extra>',
                customdata=np.column_stack((
                    filtered_df[col1],
                    filtered_df[col2],
                    peptide_ids,
                    function_labels
                ))
            ))

            # Add trendline if we have enough points
            if len(filtered_df) > 1:
                z = np.polyfit(filtered_df['log_x'], filtered_df['log_y'], 1)
                p = np.poly1d(z)
                x_range = np.linspace(filtered_df['log_x'].min(), filtered_df['log_x'].max(), 100)

                fig.add_trace(go.Scatter(
                    x=x_range,
                    y=p(x_range),
                    mode='lines',
                    name='Trendline',
                    line=dict(color=color, dash='dash'),
                    hovertemplate='<extra></extra>'
                ))

            # Axes are in log10 units; label whole decades as powers of ten
            x_tickvals, x_ticktext = decade_ticks(filtered_df['log_x'])
            y_tickvals, y_ticktext = decade_ticks(filtered_df['log_y'])

            # Update layout
            fig.update_layout(
                title=dict(
                    text=f'Correlation Plot (r = {corr:.2f})',
                    x=0.5,
                    xanchor='center'
                ),
                xaxis_title=xlabel if xlabel else group1,
                yaxis_title=ylabel if ylabel else group2,
                xaxis=dict(
                    ticktext=x_ticktext,
                    tickvals=x_tickvals,
                ),
                yaxis=dict(
                    ticktext=y_ticktext,
                    tickvals=y_tickvals,
                ),
                showlegend=False,
                width=800,
                height=800,
                template='simple_white',
                hoverlabel=dict(
                    bgcolor="white",
                    font_size=14,
                    font_family="Arial"
                )
            )

            # Make aspect ratio equal
            fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1))

            # Show plot
            fig.show(config={
                'displayModeBar': True,
                'scrollZoom': True,
                'modeBarButtonsToRemove': ['select2d', 'lasso2d']
            })

    # Create widget container with vertical layout
    controls = widgets.VBox([
        widgets.VBox([
            group1_widget,
            group2_widget,
            function_widget,
            color_widget,
            xlabel_widget,
            ylabel_widget
        ]),
        info_output
    ])

    # Observe widget changes
    group1_widget.observe(create_plot, names='value')
    group2_widget.observe(create_plot, names='value')
    function_widget.observe(create_plot, names='value')
    color_widget.observe(create_plot, names='value')
    xlabel_widget.observe(create_plot, names='value')
    ylabel_widget.observe(create_plot, names='value')

    # Display widgets and create initial plot
    display(controls)
    display(plot_output)
    create_plot()
    

# Output widgets for the plot / export sections, registered by
# initialize_notebook(). Empty until the interface has been built.
_section_outputs = {}


def update_visualizations():
    """
    Update all visualizations when data changes.

    When ``initialize_notebook()`` has registered its section outputs, each
    section is cleared and re-rendered in place. Otherwise the plot and the
    export links are rendered into whatever output is currently active.
    """
    plot_area = _section_outputs.get('plot')
    export_area = _section_outputs.get('export')

    if plot_area is None or export_area is None:
        # Interface not built through initialize_notebook(): render inline
        plot_interactive_correlation()
        setup_correlation_exports()
        return

    with plot_area:
        plot_area.clear_output()
        plot_interactive_correlation()

    with export_area:
        export_area.clear_output()
        setup_correlation_exports()


# Define the initialization and display function
def initialize_notebook():
    """
    Build and display the notebook interface.

    Layout, top to bottom: upload section (both uploaders plus their status
    area), "Correlation Plot" section, "Export Options" section. The two
    section outputs are registered so that ``update_visualizations()`` can
    refresh them after each upload.
    """
    # Initialize everything
    merged_uploader, group_uploader, output_area = load_data_files()

    # Create container for visualization and export sections
    visualization_output = widgets.Output()
    export_output = widgets.Output()

    # Register before the first render so later uploads refresh these sections
    _section_outputs['plot'] = visualization_output
    _section_outputs['export'] = export_output

    # Initial render (shows the "please load" hints until both files are present)
    update_visualizations()

    # Create section headers
    upload_header = widgets.HTML("<h2>Data Upload</h2>")
    plot_header = widgets.HTML("<h2>Correlation Plot</h2>")
    export_header = widgets.HTML("<h2>Export Options</h2>")

    # Create upload section with both uploaders
    upload_section = widgets.VBox([
        upload_header,
        merged_uploader,  # Direct widget, not wrapped in HTML
        group_uploader,   # Direct widget, not wrapped in HTML
        output_area
    ])

    # Create main container
    main_container = widgets.VBox([
        upload_section,
        plot_header,
        visualization_output,
        export_header,
        export_output
    ])

    # Display the main container
    display(main_container)

# Run everything
# Create the shared data manager only if this namespace does not already have
# one, so re-running this code keeps data that was uploaded earlier.
if 'data_manager' not in globals():
    data_manager = DataManager()

# Build and display the upload / plot / export interface
initialize_notebook()

VBox(children=(VBox(children=(HTML(value='<h2>Data Upload</h2>'), FileUpload(value=(), accept='.csv,.txt,.tsv,…