In [None]:
import pandas as pd
import numpy as np

# Assuming your DataFrame is named 'threshold_df' and already loaded
# If not, you would load it like this:
# threshold_df = pd.read_csv('your_file.csv')

# Compartment -> pH columns ('7_0' = pH 7.0, ...) in which a non-zero value
# counts as a prediction compatible with that compartment.
loc_dict = {
    'Cytoplasm': ['7_0', '7_3'],
    'Nucleus': ['7_0', '7_3'],
    'Plastid': ['7_0', '7_3', '7_9', '8_2'],
    'Mitochondria': ['7_6', '7_9'],
    'Peroxisome': ['8_2']
}


def _has_prediction(value):
    """Return True if a pH-column cell holds a real (non-zero, non-missing) value.

    The "no prediction" marker is accepted both as the string '0' and as a
    numeric 0. If pandas parsed a pH column as numbers, the zeros are then
    not mistaken for hits, because ``0 != '0'`` is always True.
    """
    if pd.isna(value):
        return False
    if isinstance(value, str):
        return value != '0'
    return value != 0


def _split_localisations(value):
    """Split a '|'-separated localisation string; a missing value gives []."""
    if pd.isna(value):
        return []
    return str(value).split('|')


def check_matching_prediction(row):
    """Return 'Yes' if any localisation of *row* has a prediction in its pH columns.

    Each '|'-separated entry of ``row['localisation']`` that is a key of
    ``loc_dict`` is checked against that compartment's pH columns. The result
    is 'Yes' as soon as one of those cells holds a non-zero, non-missing value,
    otherwise 'No'. Unknown compartments and a missing localisation yield 'No'.
    """
    for loc in _split_localisations(row['localisation']):
        columns_to_check = loc_dict.get(loc)
        if columns_to_check and any(_has_prediction(row[col]) for col in columns_to_check):
            return 'Yes'
    return 'No'


# Specific clades whose Cytoplasm call is re-evaluated
clades_to_check = ['AHLa', 'AHLb', 'Basal_AHL', 'Algal_AHL']

# Columns to check for Cytoplasm in those clades (adds '7_6' to loc_dict's list)
cytoplasm_columns = ['7_0', '7_3', '7_6']


def check_AHL_prediction(row):
    """Override ``row['matching_prediction']`` for cytoplasmic proteins of AHL clades.

    If the clade is in ``clades_to_check`` and 'Cytoplasm' is one of the row's
    localisations, return 'Yes' when any ``cytoplasm_columns`` cell holds a
    non-zero, non-missing value, else 'No'. All other rows keep their existing
    ``matching_prediction`` value.
    """
    if (row['clade'] in clades_to_check
            and 'Cytoplasm' in _split_localisations(row['localisation'])):
        return 'Yes' if any(_has_prediction(row[col]) for col in cytoplasm_columns) else 'No'
    return row['matching_prediction']

# Run the two predictors in order over every row:
#   1. check_matching_prediction - generic compartment / pH-column match
#   2. check_AHL_prediction      - reads the column written by step 1 and only
#                                  changes rows whose clade is in clades_to_check
for predictor in (check_matching_prediction, check_AHL_prediction):
    threshold_df['matching_prediction'] = threshold_df.apply(predictor, axis=1)

# Display the result
summary_columns = ['seq_id', 'clade', 'localisation', 'matching_prediction']
print(threshold_df[summary_columns])

In [None]:
import numpy as np
import plotly.graph_objects as go

def read_potential_file(filename):
    """Parse a regular-grid OpenDX (``.dx``) scalar field such as an electrostatic potential.

    Header lines recognised:
      * ``object ... class gridpositions counts nx ny nz`` -> grid counts
      * ``origin x y z``                                   -> grid origin
      * ``delta dx dy dz`` (one per axis)                  -> grid step vectors
    Every line whose first token parses as a float is treated as data, so
    indented rows and values such as ``.5`` or ``+1e-3`` are kept. All other
    lines (comments, ``attribute``/``component``/other ``object`` lines) are
    skipped.

    Returns:
        (counts, origin, deltas, potentials) as numpy arrays: int counts, float
        origin, one delta row per axis, and the flat list of data values in
        file order.

    Raises:
        ValueError: if the gridpositions, origin or delta header lines are missing.
    """
    counts = None
    origin = None
    deltas = []
    potentials = []

    with open(filename, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            keyword = tokens[0]
            if keyword == 'object' and 'gridpositions' in tokens and 'counts' in tokens:
                counts = [int(x) for x in tokens[tokens.index('counts') + 1:]]
            elif keyword == 'origin':
                origin = [float(x) for x in tokens[1:]]
            elif keyword == 'delta':
                deltas.append([float(x) for x in tokens[1:]])
            else:
                try:
                    first_value = float(keyword)
                except ValueError:
                    continue  # non-data line (comment, attribute, trailer, ...)
                potentials.append(first_value)
                potentials.extend(float(x) for x in tokens[1:])

    if counts is None or origin is None or not deltas:
        raise ValueError(
            f'{filename}: missing gridpositions counts, origin or delta header lines'
        )

    return np.array(counts), np.array(origin), np.array(deltas), np.array(potentials)

def create_grid_coordinates(counts, origin, deltas):
    """Return X, Y, Z coordinate arrays for every node of a regular 3-D grid.

    Node (i, j, k) lies at ``origin + i*deltas[0] + j*deltas[1] + k*deltas[2]``,
    as in the OpenDX gridpositions convention. The spacing along each axis is
    therefore exactly one delta, and the last node along axis 0 sits at
    ``origin + (counts[0] - 1) * deltas[0]``. Off-diagonal delta components,
    i.e. non-orthogonal grids, are honoured.

    Args:
        counts: number of nodes along each of the 3 axes.
        origin: coordinates of node (0, 0, 0).
        deltas: 3x3 array-like; row k is the step vector along axis k.

    Returns:
        (X, Y, Z), each of shape ``tuple(counts)``, indexed as [i, j, k].
    """
    counts = np.asarray(counts, dtype=int)
    origin = np.asarray(origin, dtype=float)
    deltas = np.asarray(deltas, dtype=float)

    I, J, K = np.meshgrid(np.arange(counts[0]), np.arange(counts[1]),
                          np.arange(counts[2]), indexing='ij')
    X = origin[0] + I * deltas[0, 0] + J * deltas[1, 0] + K * deltas[2, 0]
    Y = origin[1] + I * deltas[0, 1] + J * deltas[1, 1] + K * deltas[2, 1]
    Z = origin[2] + I * deltas[0, 2] + J * deltas[1, 2] + K * deltas[2, 2]
    return X, Y, Z

def _signed_log1p(values):
    """Return ``sign(v) * log(1 + |v|)`` element-wise.

    Compresses large magnitudes while keeping the sign, so positive and
    negative potentials stay at opposite ends of the colour scale.
    """
    return np.sign(values) * np.log1p(np.abs(values))


def _log_colorbar(log_values, title='Log Potential (kT/e)', n_ticks=6):
    """Build a Plotly colorbar dict for colours produced by ``_signed_log1p``.

    Tick positions are evenly spaced in the transformed (colour) space. Each
    label is the inverse transform ``sign(t) * (exp(|t|) - 1)`` of its own
    position, so it shows the real potential at that colour. Evenly spaced
    linear potentials would not line up with log-spaced ticks.
    """
    tickvals = np.linspace(np.min(log_values), np.max(log_values), n_ticks)
    tick_potentials = np.sign(tickvals) * np.expm1(np.abs(tickvals))
    return dict(
        title=title,
        ticktext=[f'{v:.1e}' for v in tick_potentials],
        tickvals=tickvals,
    )


def visualize_potentials(filename, sample_rate=5, filter_range=(-0.1, 0.1), 
                        point_style='dynamic', normal_point_size=4, 
                        range_point_size=1, min_dynamic_size=2, max_dynamic_size=8,
                        opacity=0.8, use_gray=True):
    """
    Visualization with multiple point style options and customizable point sizes and colors.

    Points are coloured by ``sign(v) * log(1 + |v|)`` of the potential ``v``;
    colorbar labels show the untransformed potential at each tick.

    Parameters:
    -----------
    filename : str
        Path to the potential data file (OpenDX format, read with read_potential_file)
    sample_rate : int
        Sample every nth point along each axis to reduce plot size
    filter_range : tuple
        (min, max) range of values to filter or show as small points
    point_style : str
        'dynamic' - Point size scales with absolute potential value
        'hide' - Don't show values in filter_range
        'small' - Show values in filter_range with specified size
    normal_point_size : int
        Size of points outside filter_range (for 'hide'/'small' styles)
    range_point_size : int
        Size of points within filter_range (if point_style='small')
    min_dynamic_size : int
        Minimum point size for dynamic sizing
    max_dynamic_size : int
        Maximum point size for dynamic sizing
    opacity : float
        Opacity of points (0-1); within-range points use half of it
    use_gray : bool
        If True, use light gray for points in range; if False, colour them on
        the same viridis scale (and colour range) as the outside-range points

    Returns:
    --------
    plotly.graph_objects.Figure

    Raises:
    -------
    ValueError
        If point_style is not one of 'dynamic', 'hide', 'small'.

    Note: relies on read_potential_file, create_grid_coordinates, numpy (np)
    and plotly.graph_objects (go) being available at module level.
    """
    if point_style not in ('dynamic', 'hide', 'small'):
        raise ValueError(
            f"point_style must be 'dynamic', 'hide' or 'small', got {point_style!r}"
        )

    # Read data and build the coordinate grid
    counts, origin, deltas, potentials = read_potential_file(filename)
    potentials = potentials.reshape(counts)
    X, Y, Z = create_grid_coordinates(counts, origin, deltas)

    # Keep every `sample_rate`-th grid point along each axis
    X = X[::sample_rate, ::sample_rate, ::sample_rate]
    Y = Y[::sample_rate, ::sample_rate, ::sample_rate]
    Z = Z[::sample_rate, ::sample_rate, ::sample_rate]
    potentials = potentials[::sample_rate, ::sample_rate, ::sample_rate]

    # Masks for values outside / inside the filter range
    outside_range = (potentials < filter_range[0]) | (potentials > filter_range[1])
    within_range = ~outside_range

    hover = 'x: %{x:.2f}<br>y: %{y:.2f}<br>z: %{z:.2f}<br>potential: %{text:.3e}'

    fig = go.Figure()

    if point_style == 'dynamic':
        # Keep all points but scale their sizes with |potential|
        log_potentials = _signed_log1p(potentials)
        abs_potentials = np.abs(potentials)
        max_abs_potential = abs_potentials.max()
        if max_abs_potential > 0:
            normalized_sizes = abs_potentials / max_abs_potential
        else:
            # All-zero map: avoid 0/0 (NaN sizes); draw every point at the minimum size.
            normalized_sizes = np.zeros_like(abs_potentials)
        point_sizes = min_dynamic_size + (max_dynamic_size - min_dynamic_size) * normalized_sizes

        fig.add_trace(go.Scatter3d(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            mode='markers',
            marker=dict(
                size=point_sizes.flatten(),
                color=log_potentials.flatten(),
                colorscale='viridis',
                opacity=opacity,
                colorbar=_log_colorbar(log_potentials),
            ),
            hovertemplate=hover + '<br>point size: %{marker.size:.1f}<extra></extra>',
            text=potentials.flatten(),
        ))

    else:
        # Colour range of the outside trace; reused by a viridis within-range
        # trace so both traces read off the single colorbar.
        color_limits = None

        if outside_range.any():  # an empty selection would crash .min()/.max()
            potentials_outside = potentials[outside_range]
            log_potentials_outside = _signed_log1p(potentials_outside)
            color_limits = (log_potentials_outside.min(), log_potentials_outside.max())

            fig.add_trace(go.Scatter3d(
                x=X[outside_range],
                y=Y[outside_range],
                z=Z[outside_range],
                mode='markers',
                marker=dict(
                    size=normal_point_size,
                    color=log_potentials_outside,
                    colorscale='viridis',
                    opacity=opacity,
                    colorbar=_log_colorbar(log_potentials_outside),
                ),
                name='Outside range',
                hovertemplate=hover + '<extra></extra>',
                text=potentials_outside,
            ))

        if point_style == 'small' and within_range.any():
            potentials_within = potentials[within_range]
            within_marker = dict(
                size=range_point_size,
                opacity=opacity * 0.5,
                showscale=False,  # never draw a second colorbar
            )
            if use_gray:
                within_marker['color'] = 'lightgray'
            else:
                within_marker['color'] = _signed_log1p(potentials_within)
                within_marker['colorscale'] = 'viridis'
                if color_limits is not None:
                    within_marker['cmin'], within_marker['cmax'] = color_limits

            fig.add_trace(go.Scatter3d(
                x=X[within_range],
                y=Y[within_range],
                z=Z[within_range],
                mode='markers',
                marker=within_marker,
                name='Within range',
                hovertemplate=hover + '<extra></extra>',
                text=potentials_within,
            ))

    # Update layout
    style_text = {
        'dynamic': 'point size scaled with potential magnitude',
        'hide': f'values {filter_range[0]:.1e} to {filter_range[1]:.1e} hidden',
        'small': f'values {filter_range[0]:.1e} to {filter_range[1]:.1e} shown with size {range_point_size}'
    }
    color_text = ' (gray)' if use_gray and point_style == 'small' else ''

    fig.update_layout(
        title=f'3D Electrostatic Potential Map<br>{style_text[point_style]}{color_text}',
        legend=dict(
            x=0,          # Move legend to the left side
            y=1,          # Keep at top
            xanchor='left',
            yanchor='top',
            bgcolor='rgba(255,255,255,0.7)'  # Semi-transparent white background
        ),
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)
            ),
        ),
        width=900,
        height=900
    )

    return fig

# Example usage:
# Dynamic point sizing
# fig = visualize_potentials('potential.dx', point_style='dynamic', 
#                          min_dynamic_size=2, max_dynamic_size=8)

# Hide values between -0.1 and 0.1
# fig = visualize_potentials('potential.dx', point_style='hide',
#                          filter_range=(-0.1, 0.1))

# Show values between -0.1 and 0.1 with specific point sizes (gray)
# fig = visualize_potentials('potential.dx', point_style='small',
#                          filter_range=(-0.1, 0.1),
#                          normal_point_size=4,
#                          range_point_size=1,
#                          use_gray=True)

# Show values between -0.1 and 0.1 with specific point sizes (viridis)
# fig = visualize_potentials('potential.dx', point_style='small',
#                          filter_range=(-0.1, 0.1),
#                          normal_point_size=4,
#                          range_point_size=1,
#                          use_gray=False)

# fig.show()

### Main figure code

In [None]:
def create_small_multiples(df):
    """
    Creates small multiple heatmaps showing localization patterns for each clade,
    aggregated by taxa, with improved layout.

    One heatmap is drawn per clade, three per row. Rows are compartments and
    columns are taxa. Each cell is (Yes - No) / (Yes + No), counted over that
    compartment column for the clade's rows of that taxon. A cell with no
    'Yes'/'No' entries is 0.

    Args:
        df: DataFrame with 'clade' and 'taxa' columns and one 'Yes'/'No' column
            per compartment (Cytoplasm, Nucleus, Plastid, Mitochondria,
            Peroxisome).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if df contains no (non-missing) clade.

    Note: relies on plt, sns and np being imported at notebook level.
    """
    # Get unique clades and compartments (a NaN clade would match no rows)
    clades = df['clade'].dropna().unique()
    compartments = ['Cytoplasm', 'Nucleus', 'Plastid', 'Mitochondria', 'Peroxisome']

    n_clades = len(clades)
    if n_clades == 0:
        # Otherwise the grid arithmetic below divides by zero.
        raise ValueError('create_small_multiples: no clades to plot')
    n_cols = min(3, n_clades)
    n_rows = (n_clades + n_cols - 1) // n_cols

    # Large figure with generous spacing between panels
    fig = plt.figure(figsize=(24, 7*n_rows))
    gs = fig.add_gridspec(n_rows, n_cols, hspace=0.5, wspace=0.4)
    # squeeze=False keeps axes 2-D even for a 1x1 grid (single clade). The
    # default would return a bare Axes, which cannot be reshaped or indexed.
    axes = gs.subplots(squeeze=False)

    # Create heatmap for each clade
    for idx, clade in enumerate(clades):
        ax = axes[idx // n_cols, idx % n_cols]
        clade_df = df[df['clade'] == clade]
        unique_taxa = clade_df['taxa'].unique()

        # Rows: compartments; columns: taxa
        taxa_proportions = []
        for compartment in compartments:
            compartment_proportions = []
            for taxon in unique_taxa:
                taxon_data = clade_df.loc[clade_df['taxa'] == taxon, compartment]
                yes_count = (taxon_data == 'Yes').sum()
                no_count = (taxon_data == 'No').sum()
                total = yes_count + no_count
                proportion = (yes_count - no_count) / total if total > 0 else 0
                compartment_proportions.append(proportion)
            taxa_proportions.append(compartment_proportions)

        heatmap_data = np.array(taxa_proportions)

        sns.heatmap(heatmap_data,
                   cmap='RdBu',
                   center=0,
                   yticklabels=compartments,
                   xticklabels=unique_taxa,
                   ax=ax,
                   cbar_kws={'label': 'Proportion (Yes-No)/(Yes+No)'},
                   linewidths=0)

        # Title and axis labels
        ax.set_title(f'Clade: {clade}', pad=20, fontsize=14, fontweight='bold')
        ax.set_xlabel('Taxa', labelpad=15)
        ax.set_ylabel('Compartments', labelpad=15)

        # Rotate and align x-axis labels
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=45,
            horizontalalignment='right',
            fontsize=10
        )

        # Adjust y-axis label size
        ax.set_yticklabels(
            ax.get_yticklabels(),
            fontsize=12
        )

        # Turn off the axes grid so it is not drawn over the heatmap cells
        ax.grid(False)

    # Remove empty subplots in the last row, if any
    for idx in range(n_clades, n_rows * n_cols):
        fig.delaxes(axes[idx // n_cols, idx % n_cols])

    # Add a title to the entire figure
    fig.suptitle('Protein Localization Patterns Across Clades and Taxa', 
                fontsize=16, 
                fontweight='bold', 
                y=1.02)

    # Adjust layout to prevent overlapping
    plt.tight_layout()
    return fig

# plt / sns are used here and by create_small_multiples above, but are only
# imported in a later cell. Import them now so this cell runs top-down.
import matplotlib.pyplot as plt
import seaborn as sns

# Set the style for better visibility.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names. On current releases 'seaborn-whitegrid' raises
# OSError, so prefer the new name and fall back only on old matplotlib.
_WHITEGRID_STYLE = (
    'seaborn-v0_8-whitegrid'
    if 'seaborn-v0_8-whitegrid' in plt.style.available
    else 'seaborn-whitegrid'
)
plt.style.use(_WHITEGRID_STYLE)
sns.set_context("notebook", font_scale=1.2)

### Upset plots for main figure

In [None]:
def create_upset_plot(df, data_type='Yes'):
    """
    Creates an UpSet-style plot with statistical analysis including:
    - Compartment percentages per clade
    - Chi-square test results
    - Jaccard similarity indices between clades

    Layout (8x2 outer grid): three bar+matrix panels occupy rows 0-1 (CNP),
    2-3 (SAL) and 4-5 (the four AHL clades, as stacked bars). Rows 6-7 hold a
    text panel with the statistics.

    Args:
        df: DataFrame with a 'clade' column and one column per compartment
            ('Cytoplasm', 'Nucleus', 'Plastid', 'Mitochondria', 'Peroxisome').
        data_type: cell value that marks a protein as present in a
            compartment (default 'Yes').

    Returns:
        The matplotlib Figure.

    Note: relies on plt, np and pd being imported at notebook level.
    """
    from scipy.stats import chi2_contingency
    import itertools

    compartments = ['Cytoplasm', 'Nucleus', 'Plastid', 'Mitochondria', 'Peroxisome']
    # Short labels for the combination axis. First letters alone are ambiguous
    # (Plastid and Peroxisome would both be 'P').
    compartment_abbrev = {
        'Cytoplasm': 'C',
        'Nucleus': 'N',
        'Plastid': 'Pl',
        'Mitochondria': 'M',
        'Peroxisome': 'Px',
    }
    ahl_clades = ['AHLa', 'AHLb', 'Basal_AHL', 'Algal_AHL']

    # Separate dataframes for each group
    df_cnp = df[df['clade'] == 'CNP']
    df_sal = df[df['clade'] == 'SAL']
    df_ahl = df[df['clade'].isin(ahl_clades)]

    def calculate_compartment_percentages(sub_df):
        """Percent of sub_df rows equal to data_type per compartment (0 for an empty group)."""
        total_proteins = len(sub_df)
        percentages = {}
        for comp in compartments:
            count = (sub_df[comp] == data_type).sum()
            percentages[comp] = (count / total_proteins) * 100 if total_proteins else 0.0
        return percentages

    def calculate_jaccard_similarity(df1, df2):
        """Jaccard index between the sets of compartment combinations seen in df1 and df2."""
        def get_combinations(frame):
            combinations = set()
            for _, row in frame.iterrows():
                current = frozenset(comp for comp in compartments
                                    if pd.notna(row[comp]) and row[comp] == data_type)
                if current:
                    combinations.add(current)
            return combinations

        set1 = get_combinations(df1)
        set2 = get_combinations(df2)
        intersection = len(set1.intersection(set2))
        union = len(set1.union(set2))
        return intersection / union if union > 0 else 0

    def perform_chi_square_test(df):
        """
        Performs chi-square test with handling for zero counts.
        Only includes compartments that have at least one protein in each clade.
        Returns (chi2, p_value, valid_compartments) or (None, None, None).
        """
        contingency_table = []
        clades = df['clade'].unique()
        if len(clades) == 0:
            # An empty table would make chi2_contingency raise.
            return None, None, None

        # First, check which compartments have proteins in all clades
        valid_compartments = []
        for comp in compartments:
            has_proteins = True
            for clade in clades:
                clade_df = df[df['clade'] == clade]
                count = len(clade_df[clade_df[comp] == data_type])
                if count == 0:
                    has_proteins = False
                    break
            if has_proteins:
                valid_compartments.append(comp)

        # If no valid compartments for chi-square test, return None
        if not valid_compartments:
            return None, None, None

        # Create contingency table with only valid compartments
        for clade in clades:
            clade_df = df[df['clade'] == clade]
            row = []
            for comp in valid_compartments:
                count = len(clade_df[clade_df[comp] == data_type])
                row.append(count)
            contingency_table.append(row)

        chi2, p_value, _, _ = chi2_contingency(contingency_table)
        return chi2, p_value, valid_compartments

    def get_combinations_and_counts(sub_df):
        """Return (combinations sorted by count desc, total counts, per-clade counts)."""
        all_combinations = []
        counts = {}
        clade_counts = {}

        for _, row in sub_df.iterrows():
            current_locations = set()
            valid_row = False
            for comp in compartments:
                if pd.notna(row[comp]) and row[comp] == data_type:
                    current_locations.add(comp)
                    valid_row = True

            if valid_row:
                current_set = frozenset(current_locations)
                counts[current_set] = counts.get(current_set, 0) + 1
                if current_set not in all_combinations:
                    all_combinations.append(current_set)
                    clade_counts[current_set] = {clade: 0 for clade in sub_df['clade'].unique()}
                clade_counts[current_set][row['clade']] += 1

        all_combinations.sort(key=lambda x: counts[x], reverse=True)
        return all_combinations, counts, clade_counts

    # Create figure with enough space for plots and stats
    fig = plt.figure(figsize=(15, 28))
    gs = fig.add_gridspec(8, 2, width_ratios=[4, 1], 
                         height_ratios=[1.5, 1, 1.5, 1, 1.5, 1, 2, 1], 
                         hspace=0.3, wspace=0.1)

    # Colors for AHL clades
    ahl_colors = {
        'AHLa': '#2a788e',
        'AHLb': '#22a884',
        'Basal_AHL': '#414487',
        'Algal_AHL': '#7ad151'
    }

    def create_subplot_pair(data, row_idx, title):
        """Draw the bar chart + membership matrix for one group as panel number row_idx."""
        combinations, counts, clade_counts = get_combinations_and_counts(data)
        n_combinations = len(combinations)

        if n_combinations == 0:
            return

        # Panel row_idx spans two outer rows: 0 -> rows 0-1, 1 -> rows 2-3, 2 -> rows 4-5.
        # (A bare gs[row_idx] would flat-index the 8x2 grid and land in the wrong cell.)
        grid_rows = slice(2 * row_idx, 2 * row_idx + 2)
        nested_gs = gs[grid_rows, 0].subgridspec(2, 1, height_ratios=[1.5, 1], hspace=0.05)

        # Shared by the bars and the matrix rectangles so they line up
        bar_width = 0.8

        # Top subplot for intersection sizes
        ax_sets = fig.add_subplot(nested_gs[0])
        ax_sets_right = fig.add_subplot(gs[grid_rows, 1])

        x_positions = np.arange(n_combinations)

        if title in ['CNP', 'SAL']:
            # Single color bars for CNP and SAL
            color = '#440154' if title == 'CNP' else '#fde725'
            heights = [counts[comb] for comb in combinations]
            ax_sets.bar(x_positions, heights, color=color, width=bar_width)

            # Add value labels
            for i, size in enumerate(heights):
                ax_sets.text(i, size, str(size), ha='center', va='bottom')
        else:
            # Stacked bars for AHL clades
            bottom = np.zeros(n_combinations)
            for clade in sorted(data['clade'].unique()):
                heights = [clade_counts[comb][clade] for comb in combinations]
                ax_sets.bar(x_positions, heights, bottom=bottom,
                            label=clade, color=ahl_colors[clade], width=bar_width)
                bottom += heights

            # Add total value labels
            totals = [counts[comb] for comb in combinations]
            for i, total in enumerate(totals):
                ax_sets.text(i, total, str(total), ha='center', va='bottom')

            ax_sets.legend(bbox_to_anchor=(1.3, 1), loc='upper left')

        ax_sets.grid(True, axis='y', linestyle='--', alpha=0.7)
        ax_sets.set_xticks([])
        ax_sets.set_ylabel('Number of Proteins')
        ax_sets.set_xlim(-0.6, n_combinations - 0.4)
        ax_sets.set_title(f'{title} Compartment Distribution')

        # Right subplot only reserves space (e.g. for the legend)
        ax_sets_right.axis('off')

        # Bottom subplot for the membership matrix
        ax_matrix = fig.add_subplot(nested_gs[1])

        matrix_data = np.zeros((len(compartments), n_combinations))
        for i, comp in enumerate(compartments):
            for j, comb in enumerate(combinations):
                matrix_data[i, j] = 1 if comp in comb else 0

        # Plot matrix using rectangles
        for i in range(len(compartments)):
            for j in range(n_combinations):
                if matrix_data[i, j]:
                    rect = plt.Rectangle((j - bar_width/2, i-0.4), bar_width, 0.8,
                                         facecolor='black')
                    ax_matrix.add_patch(rect)

        ax_matrix.set_xlim(-0.6, n_combinations - 0.4)
        ax_matrix.set_ylim(-0.5, len(compartments) - 0.5)
        ax_matrix.set_yticks(np.arange(len(compartments)))
        ax_matrix.set_yticklabels(compartments)

        ax_matrix.set_xticks(x_positions)
        combination_labels = ['+'.join(sorted(compartment_abbrev[comp] for comp in comb))
                              for comb in combinations]
        ax_matrix.set_xticklabels(combination_labels, rotation=45, ha='right')

        ax_matrix.grid(True, which='major', linestyle='-', alpha=0.3)

    # Create subplots for each group
    create_subplot_pair(df_cnp, 0, 'CNP')
    create_subplot_pair(df_sal, 1, 'SAL')
    create_subplot_pair(df_ahl, 2, 'AHL Clades')

    # Create statistics subplot
    stats_ax = fig.add_subplot(gs[6:8, :])
    stats_ax.axis('off')

    # y position (axes fraction) of the next text line
    y_position = 0.95

    stats_ax.text(0.02, y_position, 'Compartment Percentages by Clade:', 
                 fontsize=12, fontweight='bold')

    # Display percentages for CNP and SAL
    y_position -= 0.1
    for group_df, name in [(df_cnp, 'CNP'), (df_sal, 'SAL')]:
        percentages = calculate_compartment_percentages(group_df)
        stats_text = f"{name}: " + ", ".join(
            f"{comp}: {percentages[comp]:.1f}%" for comp in compartments)
        stats_ax.text(0.02, y_position, stats_text)
        y_position -= 0.05

    # Display percentages for AHL clades
    for clade in ahl_clades:
        clade_df = df[df['clade'] == clade]
        percentages = calculate_compartment_percentages(clade_df)
        stats_text = f"{clade}: " + ", ".join(
            f"{comp}: {percentages[comp]:.1f}%" for comp in compartments)
        stats_ax.text(0.02, y_position, stats_text)
        y_position -= 0.05

    # Add chi-square test results
    y_position -= 0.1
    chi2, p_value, valid_comps = perform_chi_square_test(df)
    if chi2 is not None:
        stats_ax.text(0.02, y_position, 
                     f"Chi-square test for independence of compartment distribution across clades:",
                     fontsize=12, fontweight='bold')
        y_position -= 0.05
        stats_ax.text(0.02, y_position, 
                     f"χ² = {chi2:.1f}, p-value = {p_value:.2e}")
        y_position -= 0.05
        stats_ax.text(0.02, y_position, 
                     f"(Test performed on compartments: {', '.join(valid_comps)})")
    else:
        stats_ax.text(0.02, y_position, 
                     "Chi-square test could not be performed due to zero counts",
                     fontsize=12, fontweight='bold')

    # Add Jaccard similarity indices
    y_position -= 0.1
    stats_ax.text(0.02, y_position, 
                 "Jaccard Similarity Indices (measures pattern similarity between clades):",
                 fontsize=12, fontweight='bold')
    y_position -= 0.05
    stats_ax.text(0.02, y_position, 
                 "(0 = completely different patterns, 1 = identical patterns)")

    # Pairwise Jaccard indices. They are laid out in three columns; a single
    # column of all pairs runs past the bottom of the figure.
    clades = ['CNP', 'SAL'] + ahl_clades
    n_text_cols = 3
    y_position -= 0.05
    for pair_idx, (clade1, clade2) in enumerate(itertools.combinations(clades, 2)):
        col_idx, row_offset = pair_idx % n_text_cols, pair_idx // n_text_cols
        df1 = df[df['clade'] == clade1]
        df2 = df[df['clade'] == clade2]
        jaccard = calculate_jaccard_similarity(df1, df2)
        stats_ax.text(0.02 + 0.33 * col_idx, y_position - 0.05 * row_offset,
                     f"{clade1} vs {clade2}: {jaccard:.3f}")

    # Add legend for compartment abbreviations
    legend_text = "Legend: " + ", ".join(
        f"{compartment_abbrev[comp]}={comp}" for comp in compartments)
    fig.text(0.1, 0.01, legend_text, fontsize=8)

    plt.subplots_adjust(bottom=0.15, right=0.85)

    return fig

### Updated main figure code

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from itertools import combinations

def create_small_multiples(df, min_samples=0):
    """
    Creates small multiple heatmaps showing localization patterns for each clade,
    aggregated by taxa, with improved layout and custom taxa order.
    Only includes taxa with sample counts >= min_samples.

    Columns follow the custom ``taxa_order`` list below. Taxa that are absent
    from that list are kept and appended in alphabetical order, not silently
    dropped. Each cell is (Yes - No) / (Yes + No) for that compartment column
    over the clade's rows of that taxon, or 0 if there are no 'Yes'/'No'
    entries.

    Args:
        df: Input DataFrame with 'clade', 'taxa' and one 'Yes'/'No' column per
            compartment (Cytoplasm, Nucleus, Plastid, Mitochondria, Peroxisome)
        min_samples: Minimum number of samples required for a taxon to be included

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if df contains no (non-missing) clade.

    Note: relies on plt, sns and np being imported at notebook level.
    """
    # Custom taxa order (kept consistent with pH_heatmap, incl. 'Streptophyta')
    taxa_order = ['Chromista','Rhodophyta','Glaucophyta','Chlorophyta', 'Streptophyta', 'Streptophyte_algae',
                  'Hornworts','Liverworts','Mosses','Lycophytes', 'Monilophytes','Ferns','Cycadales',
                  'Ginkgoales','Gnetales', 'Gymnos', 'Gymnosperms', 'Pinales','ANAGrade','Magnoliids', 'Monocots',
                  'Monocots/Alismatales','Monocots/Diosales','Monocots/Pandanales',
                  'Monocots/Liliales','Monocots/Asparagales','Monocots/Commelinids',
                  'Eudicots/Ranunculales','Eudicots','Eudicots/Caryophyllales',
                  'Eudicots/Asterids','Eudicots/Santalales','Eudicots/Saxifragales',
                  'Eudicots/Rosids']
    known_taxa = set(taxa_order)

    # Unique clades (a NaN clade would match no rows) and compartments
    clades = df['clade'].dropna().unique()
    compartments = ['Cytoplasm', 'Nucleus', 'Plastid', 'Mitochondria', 'Peroxisome']

    n_clades = len(clades)
    if n_clades == 0:
        # Otherwise the grid arithmetic below divides by zero.
        raise ValueError('create_small_multiples: no clades to plot')
    n_cols = min(3, n_clades)
    n_rows = (n_clades + n_cols - 1) // n_cols

    # Increase figure size and adjust spacing
    fig = plt.figure(figsize=(24, 7*n_rows))
    gs = fig.add_gridspec(n_rows, n_cols, hspace=0.8, wspace=0.8)
    # squeeze=False keeps axes 2-D even for a 1x1 grid (single clade). The
    # default would return a bare Axes, which cannot be reshaped or indexed.
    axes = gs.subplots(squeeze=False)

    # Create heatmap for each clade
    for idx, clade in enumerate(clades):
        ax = axes[idx // n_cols, idx % n_cols]
        clade_df = df[df['clade'] == clade]

        # Taxa meeting the minimum sample requirement
        taxa_counts = clade_df['taxa'].value_counts()
        valid_taxa = taxa_counts[taxa_counts >= min_samples].index

        if len(valid_taxa) == 0:
            ax.text(0.5, 0.5, f'No taxa in {clade} have ≥{min_samples} samples',
                    ha='center', va='center')
            ax.axis('off')
            continue

        clade_df = clade_df[clade_df['taxa'].isin(valid_taxa)]

        # Custom order first, then any taxa missing from taxa_order (alphabetically)
        present_taxa = set(valid_taxa)
        ordered_taxa = [taxon for taxon in taxa_order if taxon in present_taxa]
        ordered_taxa += sorted((taxon for taxon in present_taxa if taxon not in known_taxa),
                               key=str)

        # Rows: compartments; columns: ordered taxa
        taxa_proportions = []
        for compartment in compartments:
            compartment_proportions = []
            for taxon in ordered_taxa:
                taxon_data = clade_df.loc[clade_df['taxa'] == taxon, compartment]
                yes_count = (taxon_data == 'Yes').sum()
                no_count = (taxon_data == 'No').sum()
                total = yes_count + no_count
                proportion = (yes_count - no_count) / total if total > 0 else 0
                compartment_proportions.append(proportion)
            taxa_proportions.append(compartment_proportions)

        # Labels with sample counts
        taxa_labels = [f'{taxon} (n={taxa_counts[taxon]})' for taxon in ordered_taxa]

        heatmap_data = np.array(taxa_proportions)

        sns.heatmap(heatmap_data,
                   cmap='PRGn',
                   vmin=-1,
                   vmax=1,
                   center=0,
                   yticklabels=compartments,
                   xticklabels=taxa_labels,
                   ax=ax,
                   cbar_kws={'label': 'Proportion pH-localisation match/mismatch'},
                   linewidths=0.5,
                   linecolor='white')

        # Title and axis labels
        ax.set_title(f'Clade: {clade}', pad=20, fontsize=14, fontweight='bold')
        ax.set_xlabel('Taxa', labelpad=15)
        ax.set_ylabel('Subcellular Compartments', labelpad=15)

        # Rotate and align x-axis labels
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=45,
            horizontalalignment='right',
            fontsize=14
        )

        # Adjust y-axis label size
        ax.set_yticklabels(
            ax.get_yticklabels(),
            fontsize=13
        )

        # Turn off the axes grid so it is not drawn over the heatmap cells
        ax.grid(False)

    # Remove empty subplots in the last row, if any
    for idx in range(n_clades, n_rows * n_cols):
        fig.delaxes(axes[idx // n_cols, idx % n_cols])

    # Add a title to the entire figure
    fig.suptitle('Proportion of sequences in a particular taxa at a particular pH that have a predicted localisation that would likely complement AtSAL1 activity', 
                fontsize=16, 
                fontweight='bold', 
                y=1.02)

    # Adjust layout to prevent overlapping
    plt.tight_layout()
    return fig

def pH_heatmap(df, min_samples=0):
    """
    Creates small multiple heatmaps showing the proportion of non-NaN values
    for each combination of taxa and pH level within each clade.
    Only includes taxa with sample counts >= min_samples.

    Taxa are ordered by ``taxa_order`` below; taxa that are not in that list
    are appended alphabetically instead of being silently dropped.

    Args:
        df: Input DataFrame with 'clade' and 'taxa' columns and one column per
            pH level ('7_0', '7_3', '7_6', '7_9', '8_2')
        min_samples: Minimum number of samples required for a taxon to be included

    Returns:
        matplotlib Figure with one heatmap per clade
    """
    # Define custom taxa order
    taxa_order = ['Chromista','Rhodophyta','Glaucophyta','Chlorophyta', 'Streptophyta', 'Streptophyte_algae',
                  'Hornworts','Liverworts','Mosses','Lycophytes', 'Monilophytes','Ferns','Cycadales',
                  'Ginkgoales','Gnetales', 'Gymnos', 'Gymnosperms', 'Pinales','ANAGrade','Magnoliids', 'Monocots',
                  'Monocots/Alismatales','Monocots/Diosales','Monocots/Pandanales',
                  'Monocots/Liliales','Monocots/Asparagales','Monocots/Commelinids',
                  'Eudicots/Ranunculales','Eudicots','Eudicots/Caryophyllales',
                  'Eudicots/Asterids','Eudicots/Santalales','Eudicots/Saxifragales',
                  'Eudicots/Rosids']

    # Get unique clades and pH levels
    clades = df['clade'].unique()
    pH_levels = ['7_0', '7_3', '7_6', '7_9', '8_2']

    # Create nice pH labels by replacing underscore with decimal point
    pH_labels = [ph.replace('_', '.') for ph in pH_levels]

    # Subplot grid: at most 3 columns, and always at least one cell so an
    # empty DataFrame does not divide by zero / request a 0-sized grid.
    n_clades = len(clades)
    n_cols = max(1, min(3, n_clades))
    n_rows = max(1, (n_clades + n_cols - 1) // n_cols)

    fig = plt.figure(figsize=(24, 7*n_rows))
    gs = fig.add_gridspec(n_rows, n_cols, hspace=0.8, wspace=0.6)
    # squeeze=False always yields a 2-D (n_rows, n_cols) array; with the
    # default, a 1x1 grid returns a bare Axes, which has no .reshape().
    axes = gs.subplots(squeeze=False)

    # Create heatmap for each clade
    for idx, clade in enumerate(clades):
        row = idx // n_cols
        col = idx % n_cols

        # Keep only taxa of this clade that meet the minimum sample requirement
        clade_df = df[df['clade'] == clade]
        taxa_counts = clade_df['taxa'].value_counts()
        valid_taxa = taxa_counts[taxa_counts >= min_samples].index
        clade_df = clade_df[clade_df['taxa'].isin(valid_taxa)]

        # If no taxa meet the minimum sample requirement, skip this clade
        if len(valid_taxa) == 0:
            axes[row, col].text(0.5, 0.5, f'No taxa in {clade} have ≥{min_samples} samples',
                              ha='center', va='center')
            axes[row, col].axis('off')
            continue

        # Known taxa follow the custom order; unknown taxa are appended
        ordered_taxa = [taxon for taxon in taxa_order if taxon in valid_taxa]
        ordered_taxa += sorted(set(valid_taxa) - set(taxa_order), key=str)

        # Proportion of non-NaN values per (pH, taxon); filter once per taxon
        heatmap_data = np.zeros((len(pH_levels), len(ordered_taxa)))
        taxa_labels = []
        for j, taxon in enumerate(ordered_taxa):
            taxon_data = clade_df.loc[clade_df['taxa'] == taxon, pH_levels]
            total_rows = len(taxon_data)
            taxa_labels.append(f'{taxon} (n={total_rows})')
            if total_rows > 0:
                heatmap_data[:, j] = taxon_data.notna().sum().to_numpy() / total_rows

        # Create heatmap with improved parameters
        sns.heatmap(heatmap_data,
                   cmap='YlOrRd',
                   vmin=0,
                   vmax=1,
                   yticklabels=pH_labels,
                   xticklabels=taxa_labels,
                   ax=axes[row, col],
                   cbar_kws={'label': 'Proportion of complementing sequences'},
                   linewidths=0.5,
                   linecolor='white')

        # Improve title and labels
        axes[row, col].set_title(f'Clade: {clade}', pad=20, fontsize=14, fontweight='bold')
        axes[row, col].set_xlabel('Taxa', labelpad=15)
        axes[row, col].set_ylabel('pH', labelpad=15)

        # Rotate and align x-axis labels
        axes[row, col].set_xticklabels(
            axes[row, col].get_xticklabels(),
            rotation=45,
            horizontalalignment='right',
            fontsize=13
        )

        # Adjust y-axis label size
        axes[row, col].set_yticklabels(
            axes[row, col].get_yticklabels(),
            fontsize=14
        )

    # Remove empty subplots if any
    if n_rows * n_cols > n_clades:
        for idx in range(n_clades, n_rows * n_cols):
            row = idx // n_cols
            col = idx % n_cols
            fig.delaxes(axes[row, col])

    # Add a title to the entire figure
    fig.suptitle('Proportion of sequences in a particular taxa that have a value at a particular pH, likely to complement AtSAL1',
                 fontsize=16,
                 fontweight='bold',
                 y=1.02)

    # Adjust layout to prevent overlapping
    plt.tight_layout()
    return fig

# Global plotting style: seaborn whitegrid look, notebook context with fonts scaled by 1.2
plt.style.use(style="seaborn-v0_8-whitegrid")
sns.set_context(context="notebook", font_scale=1.2)

def create_upset_plot(df):
    """
    Creates an UpSet-style plot with statistical analysis including:
    - Compartment percentages per clade
    - Jaccard similarity indices between clades
    Shows both matching and mismatching predictions side by side

    Rows (top to bottom): SAL, the four AHL clades (stacked bars) and CNP. In
    each row the left panel counts compartment combinations whose cells are
    'Yes' ("Prediction matching pH") and the right panel those whose cells are
    'No' ("Prediction mismatched to pH"). All bar panels share one y-limit.

    Args:
        df: DataFrame with a 'clade' column and one 'Yes'/'No' column per
            compartment (Cytoplasm, Nucleus, Plastid, Mitochondria, Peroxisome)

    Returns:
        matplotlib Figure
    """
    import itertools
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    compartments = ['Cytoplasm', 'Nucleus', 'Plastid', 'Mitochondria', 'Peroxisome']
    compartment_abbrev = {'Cytoplasm': 'C', 'Nucleus': 'N', 'Plastid': 'P',
                         'Mitochondria': 'M', 'Peroxisome': 'Px'}

    # Separate dataframes for each group
    df_cnp = df[df['clade'] == 'CNP']
    df_sal = df[df['clade'] == 'SAL']
    df_ahl = df[df['clade'].isin(['AHLa', 'AHLb', 'Basal_AHL', 'Algal_AHL'])]

    # Colors for AHL clades
    ahl_colors = {
        'AHLa': '#2a788e',
        'AHLb': '#22a884',
        'Basal_AHL': '#414487',
        'Algal_AHL': '#7ad151'
    }

    # Initialize max_bar_height as a list to allow modification inside nested functions
    max_bar_height = [0]

    def calculate_compartment_percentages(sub_df):
        """Percentage of rows marked 'Yes' per compartment (all 0.0 for an empty frame)."""
        total_proteins = len(sub_df)
        # A clade absent from the data would otherwise raise ZeroDivisionError
        if total_proteins == 0:
            return {comp: 0.0 for comp in compartments}
        return {comp: (sub_df[comp] == 'Yes').sum() / total_proteins * 100
                for comp in compartments}

    def calculate_jaccard_similarity(df1, df2):
        """Jaccard index of the distinct non-empty 'Yes' compartment combinations in two frames."""
        def get_combinations(df):
            combinations = set()
            for _, row in df.iterrows():
                current = frozenset(comp for comp in compartments
                                  if pd.notna(row[comp]) and row[comp] == 'Yes')
                if current:
                    combinations.add(current)
            return combinations

        set1 = get_combinations(df1)
        set2 = get_combinations(df2)
        intersection = len(set1.intersection(set2))
        union = len(set1.union(set2))
        return intersection / union if union > 0 else 0

    def get_combinations_and_counts(sub_df, data_type='Yes'):
        """Count each combination of compartments equal to *data_type*, overall and per clade.

        Returns (combinations sorted by count descending, counts, per-clade counts).
        """
        all_combinations = []
        counts = {}
        clade_counts = {}

        for _, row in sub_df.iterrows():
            current_locations = set()
            valid_row = False
            for comp in compartments:
                if pd.notna(row[comp]) and row[comp] == data_type:
                    current_locations.add(comp)
                    valid_row = True

            if valid_row:
                current_set = frozenset(current_locations)
                counts[current_set] = counts.get(current_set, 0) + 1
                if current_set not in all_combinations:
                    all_combinations.append(current_set)
                    clade_counts[current_set] = {clade: 0 for clade in sub_df['clade'].unique()}
                clade_counts[current_set][row['clade']] += 1

        all_combinations.sort(key=lambda x: counts[x], reverse=True)
        return all_combinations, counts, clade_counts

    # Tall figure: three UpSet rows plus a statistics row
    fig = plt.figure(figsize=(15, 28))

    # Create a gridspec with 4 rows - one for each plot plus stats
    gs = fig.add_gridspec(4, 1, height_ratios=[2, 2, 2, 2], hspace=0.4)

    # Create nested gridspecs for each plot section
    gs_plots = [gs[i].subgridspec(1, 2, width_ratios=[1, 1], wspace=0.3) for i in range(3)]

    def create_subplot_pair(data, gs_section, title):
        """Draw the 'Yes' (left) and 'No' (right) bar + membership-matrix panels for *data*."""
        for data_type, subplot_gs in zip(['Yes', 'No'], [gs_section[0], gs_section[1]]):
            combinations, counts, clade_counts = get_combinations_and_counts(data, data_type)
            n_combinations = len(combinations)

            if n_combinations == 0:
                continue

            nested_gs = subplot_gs.subgridspec(2, 1, height_ratios=[1.5, 1], hspace=0.05)

            bar_width = 0.8

            ax_sets = fig.add_subplot(nested_gs[0])

            x_positions = np.arange(n_combinations)

            if title in ['CNP', 'SAL']:
                color = '#440154' if title == 'CNP' else '#fde725'
                heights = [counts[comb] for comb in combinations]
                ax_sets.bar(x_positions, heights, color=color, width=bar_width)
                max_bar_height[0] = max(max_bar_height[0], max(heights) if heights else 0)
            else:
                # Stack one bar segment per AHL clade
                bottom = np.zeros(n_combinations)
                total_heights = np.zeros(n_combinations)
                for clade in sorted(data['clade'].unique()):
                    heights = [clade_counts[comb][clade] for comb in combinations]
                    ax_sets.bar(x_positions, heights, bottom=bottom,
                            label=clade, color=ahl_colors[clade], width=bar_width)
                    bottom += heights
                    total_heights += heights

                max_bar_height[0] = max(max_bar_height[0], max(total_heights) if len(total_heights) > 0 else 0)

            if data_type == 'Yes':
                plot_title = f'{title} Prediction matching pH'
            else:
                plot_title = f'{title} Prediction mismatched to pH'

            ax_sets.grid(True, axis='y', linestyle='--', alpha=0.7)
            ax_sets.set_xticks([])
            ax_sets.set_ylabel('Number of Proteins')
            ax_sets.set_xlim(-0.6, n_combinations - 0.4)
            ax_sets.set_title(plot_title)

            if title == 'AHL Clades':
                ax_sets.legend(bbox_to_anchor=(1.3, 1), loc='upper left')

            ax_matrix = fig.add_subplot(nested_gs[1])

            matrix_data = np.zeros((len(compartments), n_combinations))
            for i, comp in enumerate(compartments):
                for j, comb in enumerate(combinations):
                    matrix_data[i, j] = 1 if comp in comb else 0

            # Black boxes mark which compartments make up each combination
            for i in range(len(compartments)):
                for j in range(n_combinations):
                    if matrix_data[i, j]:
                        rect = plt.Rectangle((j - bar_width/2, i-0.4), bar_width, 0.8,
                                        facecolor='black')
                        ax_matrix.add_patch(rect)

            ax_matrix.set_xlim(-0.6, n_combinations - 0.4)
            ax_matrix.set_ylim(-0.5, len(compartments) - 0.5)
            ax_matrix.set_yticks(np.arange(len(compartments)))
            ax_matrix.set_yticklabels(compartments)

            ax_matrix.set_xticks(x_positions)
            combination_labels = ['+'.join(sorted(compartment_abbrev[comp] for comp in comb))
                                for comb in combinations]
            ax_matrix.set_xticklabels(combination_labels, rotation=45, ha='right')

            ax_matrix.grid(True, which='major', linestyle='-', alpha=0.3)

    # Create subplots in desired order: SAL, AHL, CNP
    create_subplot_pair(df_sal, gs_plots[0], 'SAL')
    create_subplot_pair(df_ahl, gs_plots[1], 'AHL Clades')
    create_subplot_pair(df_cnp, gs_plots[2], 'CNP')

    # Set all bar plots to the same y-axis limit (skip when nothing was drawn,
    # which would otherwise request an invalid (0, 0) range)
    if max_bar_height[0] > 0:
        for ax in fig.get_axes():
            if ax.get_ylabel() == 'Number of Proteins':
                ax.set_ylim(0, max_bar_height[0] * 1.1)  # Add 10% padding

    # Create statistics subplot in the bottom row
    stats_ax = fig.add_subplot(gs[3])
    stats_ax.axis('off')

    y_position = 0.95

    stats_ax.text(0.02, y_position, 'Compartment Percentages by Clade:',
                 fontsize=12, fontweight='bold')

    y_position -= 0.1
    for group_df, name in [(df_cnp, 'CNP'), (df_sal, 'SAL')]:
        percentages = calculate_compartment_percentages(group_df)
        stats_text = f"{name}: " + ", ".join(
            f"{comp}: {percentages[comp]:.1f}%" for comp in compartments)
        stats_ax.text(0.02, y_position, stats_text)
        y_position -= 0.05

    for clade in ['AHLa', 'AHLb', 'Basal_AHL', 'Algal_AHL']:
        clade_df = df[df['clade'] == clade]
        percentages = calculate_compartment_percentages(clade_df)
        stats_text = f"{clade}: " + ", ".join(
            f"{comp}: {percentages[comp]:.1f}%" for comp in compartments)
        stats_ax.text(0.02, y_position, stats_text)
        y_position -= 0.05

    y_position -= 0.1
    stats_ax.text(0.02, y_position,
                 "Jaccard Similarity Indices (measures pattern similarity between clades):",
                 fontsize=12, fontweight='bold')
    y_position -= 0.05
    stats_ax.text(0.02, y_position,
                 "(0 = completely different patterns, 1 = identical patterns)")

    clades = ['CNP', 'SAL', 'AHLa', 'AHLb', 'Basal_AHL', 'Algal_AHL']
    for clade1, clade2 in itertools.combinations(clades, 2):
        y_position -= 0.05
        df1 = df[df['clade'] == clade1]
        df2 = df[df['clade'] == clade2]
        jaccard = calculate_jaccard_similarity(df1, df2)
        stats_ax.text(0.02, y_position,
                     f"{clade1} vs {clade2}: {jaccard:.3f}")

    legend_text = "Legend: " + ", ".join(f"{compartment_abbrev[comp]}={comp}" for comp in compartments)
    fig.text(0.1, 0.01, legend_text, fontsize=8)

    plt.subplots_adjust(bottom=0.15, right=0.85, left=0.15)

    return fig

def visualize_protein_localizations(df):
    """
    Build and display the three protein-localisation figures for *df*.

    In order: pH coverage heatmaps (``pH_heatmap``), pH-localisation small
    multiples (``create_small_multiples``) and the UpSet plot
    (``create_upset_plot``). A progress message is printed before each figure,
    and the first two are followed by a one-line caption.
    """
    steps = (
        ("\nCreating complementation plot...",
         pH_heatmap,
         'This plot shows what proportion of sequences in a particular taxa have a value at a particular pH'),
        ("\nCreating pH-localisation plot...",
         create_small_multiples,
         'This plot shows what proportion of sequences in a particular taxa at a particular pH have a predicted localisation that would likely complement AtSAL1 activity'),
        ("\nCreating upset plot...",
         create_upset_plot,
         None),
    )
    for banner, build_figure, caption in steps:
        print(banner)
        build_figure(df).show()
        if caption is not None:
            print(caption)

### V3 of the above

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from itertools import combinations

# Constants for visualization

# Display order of taxa along heatmap x-axes. 'Streptophyta' is kept here (as in
# the previous version's order) so rows with that taxon are not dropped.
TAXA_ORDER = [
    'Chromista', 'Rhodophyta', 'Glaucophyta', 'Chlorophyta', 'Streptophyta',
    'Streptophyte_algae',
    'Hornworts', 'Liverworts', 'Mosses', 'Lycophytes', 'Monilophytes', 'Ferns', 
    'Cycadales', 'Ginkgoales', 'Gnetales', 'Gymnos', 'Gymnosperms', 'Pinales',
    'ANAGrade', 'Magnoliids', 'Monocots', 'Monocots/Alismatales', 'Monocots/Diosales',
    'Monocots/Pandanales', 'Monocots/Liliales', 'Monocots/Asparagales', 
    'Monocots/Commelinids', 'Eudicots/Ranunculales', 'Eudicots', 
    'Eudicots/Caryophyllales', 'Eudicots/Asterids', 'Eudicots/Santalales',
    'Eudicots/Saxifragales', 'Eudicots/Rosids'
]

# Compartment columns holding 'Yes'/'No' match flags
COMPARTMENTS = [
    'Cytoplasm', 'Guard cell cytoplasm', 'Nucleus', 'Plastid', 
    'Mitochondria', 'Peroxisome'
]

# Short labels for UpSet combination ticks and the figure legend; must cover
# every entry of COMPARTMENTS or lookups raise KeyError.
COMPARTMENT_ABBREV = {
    'Cytoplasm': 'C', 
    'Guard cell cytoplasm': 'GC',
    'Nucleus': 'N', 
    'Plastid': 'P', 
    'Mitochondria': 'M', 
    'Peroxisome': 'Px'
}

# Bar colours for the stacked AHL clades in the UpSet plot
AHL_COLORS = {
    'AHLa': '#2a788e',
    'AHLb': '#22a884',
    'Basal_AHL': '#414487',
    'Algal_AHL': '#7ad151'
}

# Global plotting style: seaborn whitegrid look, notebook context with fonts scaled by 1.2
plt.style.use(style="seaborn-v0_8-whitegrid")
sns.set_context(context="notebook", font_scale=1.2)

def create_subplot_grid(n_clades, max_cols=3):
    """
    Creates a subplot grid based on number of clades.

    The grid always has at least one cell, and ``axes`` is always a 2-D
    ``(n_rows, n_cols)`` array, so callers can index ``axes[row, col]``
    uniformly (including the single-clade and zero-clade cases).

    Args:
        n_clades: Number of clades to display
        max_cols: Maximum number of columns in grid

    Returns:
        tuple: (figure, axes, n_rows, n_cols)
    """
    # Clamp to >= 1 so an empty input does not divide by zero or build a 0-sized grid
    n_cols = max(1, min(max_cols, n_clades))
    n_rows = max(1, (n_clades + n_cols - 1) // n_cols)

    fig = plt.figure(figsize=(24, 7*n_rows))
    gs = fig.add_gridspec(n_rows, n_cols, hspace=0.8, wspace=0.8)
    # squeeze=False: with the default, a 1x1 grid returns a bare Axes (no .reshape)
    axes = gs.subplots(squeeze=False)

    return fig, axes, n_rows, n_cols

def filter_taxa_by_samples(df, clade, min_samples):
    """
    Restrict *df* to one clade and drop taxa that are under-sampled within it.

    Args:
        df: Input DataFrame with 'clade' and 'taxa' columns
        clade: Clade value whose rows are kept
        min_samples: Minimum number of rows a taxon needs inside the clade

    Returns:
        tuple: (rows of *clade* whose taxon passes the threshold,
                pandas Index of the passing taxa in value_counts order)
    """
    in_clade = df.loc[df['clade'] == clade]
    per_taxon = in_clade['taxa'].value_counts()
    kept_taxa = per_taxon[per_taxon >= min_samples].index
    keep_mask = in_clade['taxa'].isin(kept_taxa)
    return in_clade.loc[keep_mask], kept_taxa

def create_small_multiples(df, min_samples=0):
    """
    Creates small multiple heatmaps showing localization patterns for each clade,
    aggregated by taxa, with improved layout and custom taxa order.
    Only includes taxa with sample counts >= min_samples.
    Displays sample counts in each cell.

    Each cell is (Yes - No) / (Yes + No) for one compartment column and one
    taxon, ranging from -1 (all 'No') to +1 (all 'Yes'); cells with no
    'Yes'/'No' values are drawn as 0. The number printed in a cell is Yes + No.
    Taxa follow TAXA_ORDER; taxa missing from TAXA_ORDER are appended
    alphabetically rather than silently dropped.

    Args:
        df: Input DataFrame with 'clade', 'taxa' and one column per entry of
            COMPARTMENTS ('Yes'/'No'; other values are ignored)
        min_samples: Minimum number of samples required for a taxon to be included

    Returns:
        matplotlib Figure with one heatmap per clade
    """
    clades = df['clade'].unique()
    fig, axes, n_rows, n_cols = create_subplot_grid(len(clades))

    for idx, clade in enumerate(clades):
        row, col = idx // n_cols, idx % n_cols
        clade_df, valid_taxa = filter_taxa_by_samples(df, clade, min_samples)

        if len(valid_taxa) == 0:
            axes[row, col].text(0.5, 0.5, f'No taxa in {clade} have ≥{min_samples} samples',
                              ha='center', va='center')
            axes[row, col].axis('off')
            continue

        # Known taxa in TAXA_ORDER order, then any unknown taxa (so none vanish)
        ordered_taxa = [taxon for taxon in TAXA_ORDER if taxon in valid_taxa]
        ordered_taxa += sorted(set(valid_taxa) - set(TAXA_ORDER), key=str)

        # Split the clade once per taxon instead of re-filtering per compartment
        taxon_frames = {taxon: clade_df[clade_df['taxa'] == taxon] for taxon in ordered_taxa}

        # Rows = compartments, columns = taxa
        taxa_proportions = []
        taxa_counts_matrix = []
        for compartment in COMPARTMENTS:
            compartment_proportions = []
            compartment_counts = []
            for taxon in ordered_taxa:
                taxon_data = taxon_frames[taxon][compartment]
                yes_count = (taxon_data == 'Yes').sum()
                no_count = (taxon_data == 'No').sum()
                total = yes_count + no_count

                proportion = (yes_count - no_count) / total if total > 0 else 0
                compartment_proportions.append(proportion)
                compartment_counts.append(total)

            taxa_proportions.append(compartment_proportions)
            taxa_counts_matrix.append(compartment_counts)

        # Create labels with sample counts
        taxa_labels = [f'{taxon} (n={len(taxon_frames[taxon])})' for taxon in ordered_taxa]

        # Create and format heatmap
        heatmap_data = np.array(taxa_proportions)
        create_annotated_heatmap(
            axes[row, col],
            heatmap_data,
            taxa_counts_matrix,
            COMPARTMENTS,
            taxa_labels,
            f'Clade: {clade}'
        )

    # Remove empty subplots
    remove_empty_subplots(fig, axes, len(clades), n_rows, n_cols)

    # Add title and adjust layout
    fig.suptitle(
        'Proportion of sequences in a particular taxa at a particular pH that have a predicted localisation that would likely complement AtSAL1 activity',
        fontsize=16,
        fontweight='bold',
        y=1.02
    )
    plt.tight_layout()
    return fig

def create_annotated_heatmap(ax, heatmap_data, count_matrix, row_labels, col_labels, title):
    """
    Draw a diverging heatmap (PRGn, fixed -1..1 scale centred on 0) on *ax*,
    print the matching count in every cell, then apply the shared axis styling.

    Args:
        ax: Matplotlib axis to draw on
        heatmap_data: 2-D array (indexed ``[row, col]``) of colour values
        count_matrix: Nested sequence (indexed ``[row][col]``) of counts to print
        row_labels: Y tick labels, one per row
        col_labels: X tick labels, one per column
        title: Axis title
    """
    heatmap_options = dict(
        cmap='PRGn',
        vmin=-1,
        vmax=1,
        center=0,
        yticklabels=row_labels,
        xticklabels=col_labels,
        ax=ax,
        cbar_kws={'label': 'Proportion pH-localisation match/mismatch'},
        linewidths=0.5,
        linecolor='white',
    )
    sns.heatmap(heatmap_data, **heatmap_options)

    # Seaborn puts cell (r, c) at data coords (c + 0.5, r + 0.5); use white
    # text on strongly coloured cells for contrast.
    n_rows, n_cols = len(row_labels), len(col_labels)
    cells = ((r, c) for r in range(n_rows) for c in range(n_cols))
    for r, c in cells:
        shade = heatmap_data[r, c]
        label_colour = 'white' if abs(shade) > 0.5 else 'black'
        ax.text(c + 0.5, r + 0.5, str(count_matrix[r][c]),
                ha='center', va='center', color=label_colour, fontsize=10)

    format_heatmap_axis(ax, title)

def format_heatmap_axis(ax, title, xlabel='Taxa', ylabel='Subcellular Compartments'):
    """
    Formats the appearance of a heatmap axis.

    Sets a bold title and the axis captions, rotates x tick labels 45 degrees
    (right-aligned), enlarges tick-label fonts and turns the grid off.

    Args:
        ax: Matplotlib axis to format
        title: Title for the axis
        xlabel: Caption for the x-axis (default 'Taxa')
        ylabel: Caption for the y-axis (default 'Subcellular Compartments');
            pass e.g. 'pH' when the heatmap rows are pH levels
    """
    ax.set_title(title, pad=20, fontsize=14, fontweight='bold')
    ax.set_xlabel(xlabel, labelpad=15)
    ax.set_ylabel(ylabel, labelpad=15)

    ax.set_xticklabels(
        ax.get_xticklabels(),
        rotation=45,
        horizontalalignment='right',
        fontsize=14
    )

    ax.set_yticklabels(
        ax.get_yticklabels(),
        fontsize=13
    )

    ax.grid(False)

def remove_empty_subplots(fig, axes, n_plots, n_rows, n_cols):
    """
    Delete the unused trailing cells of a row-major subplot grid.

    Cells are numbered row by row; every cell numbered ``n_plots`` or higher
    (up to ``n_rows * n_cols - 1``) is removed from *fig*.

    Args:
        fig: Matplotlib figure owning the axes
        axes: 2-D array of axes indexed ``[row, col]``
        n_plots: Number of cells actually used
        n_rows: Number of rows in grid
        n_cols: Number of columns in grid
    """
    total_slots = n_rows * n_cols
    for slot in range(n_plots, total_slots):
        grid_row, grid_col = divmod(slot, n_cols)
        fig.delaxes(axes[grid_row, grid_col])

def pH_heatmap(df, min_samples=0):
    """
    Creates small multiple heatmaps showing the proportion of non-NaN values
    for each combination of taxa and pH level within each clade.
    Only includes taxa with sample counts >= min_samples.
    Each cell is annotated with the number of non-NaN values for that taxon
    and pH; the taxon's total sample count is shown in its x-axis label.

    Taxa follow TAXA_ORDER; taxa missing from TAXA_ORDER are appended
    alphabetically rather than silently dropped.

    Args:
        df: Input DataFrame with 'clade' and 'taxa' columns and one column per
            pH level ('7_0', '7_3', '7_6', '7_9', '8_2')
        min_samples: Minimum number of samples required for a taxon to be included

    Returns:
        matplotlib Figure with one heatmap per clade
    """
    clades = df['clade'].unique()
    pH_levels = ['7_0', '7_3', '7_6', '7_9', '8_2']
    pH_labels = [ph.replace('_', '.') for ph in pH_levels]

    fig, axes, n_rows, n_cols = create_subplot_grid(len(clades))

    for idx, clade in enumerate(clades):
        row, col = idx // n_cols, idx % n_cols
        clade_df, valid_taxa = filter_taxa_by_samples(df, clade, min_samples)

        if len(valid_taxa) == 0:
            axes[row, col].text(0.5, 0.5, f'No taxa in {clade} have ≥{min_samples} samples',
                              ha='center', va='center')
            axes[row, col].axis('off')
            continue

        # Known taxa in TAXA_ORDER order, then any unknown taxa (so none vanish)
        ordered_taxa = [taxon for taxon in TAXA_ORDER if taxon in valid_taxa]
        ordered_taxa += sorted(set(valid_taxa) - set(TAXA_ORDER), key=str)

        heatmap_data = np.zeros((len(pH_levels), len(ordered_taxa)))
        count_matrix = np.zeros((len(pH_levels), len(ordered_taxa)), dtype=int)
        sample_sizes = []

        # Filter once per taxon and count all pH columns together
        for j, taxon in enumerate(ordered_taxa):
            taxon_data = clade_df.loc[clade_df['taxa'] == taxon, pH_levels]
            total_rows = len(taxon_data)
            sample_sizes.append(total_rows)
            if total_rows > 0:
                non_nan_counts = taxon_data.notna().sum().to_numpy()
                count_matrix[:, j] = non_nan_counts
                heatmap_data[:, j] = non_nan_counts / total_rows

        taxa_labels = [f'{taxon} (n={n})' for taxon, n in zip(ordered_taxa, sample_sizes)]

        create_ph_heatmap(
            axes[row, col],
            heatmap_data,
            count_matrix,
            pH_labels,
            taxa_labels,
            f'Clade: {clade}'
        )

    # Remove empty subplots and add title
    remove_empty_subplots(fig, axes, len(clades), n_rows, n_cols)
    fig.suptitle(
        'Proportion of sequences in a particular taxa that have a value at a particular pH, likely to complement AtSAL1',
        fontsize=16,
        fontweight='bold',
        y=1.02
    )

    plt.tight_layout()
    return fig

def create_ph_heatmap(ax, heatmap_data, count_matrix, row_labels, col_labels, title):
    """
    Creates a pH-specific heatmap with annotations.

    Colours encode ``heatmap_data`` on a fixed 0..1 YlOrRd scale; each cell
    is annotated with ``count_matrix[i, j]``, using white text above 0.7 for
    contrast. The y-axis caption is 'pH' because the rows are pH levels.

    Args:
        ax: Matplotlib axis
        heatmap_data: Data for heatmap colors (2-D, indexed ``[i, j]``)
        count_matrix: Matrix of counts to display (2-D array, indexed ``[i, j]``)
        row_labels: Labels for rows
        col_labels: Labels for columns
        title: Title for heatmap
    """
    sns.heatmap(
        heatmap_data,
        cmap='YlOrRd',
        vmin=0,
        vmax=1,
        yticklabels=row_labels,
        xticklabels=col_labels,
        ax=ax,
        cbar_kws={'label': 'Proportion of complementing sequences'},
        linewidths=0.5,
        linecolor='white'
    )

    # Add count annotations
    for i in range(len(row_labels)):
        for j in range(len(col_labels)):
            value = heatmap_data[i, j]
            text_color = 'white' if value > 0.7 else 'black'
            ax.text(
                j + 0.5,
                i + 0.5,
                str(count_matrix[i, j]),
                ha='center',
                va='center',
                color=text_color,
                fontsize=10
            )

    # Shared formatting labels the y-axis 'Subcellular Compartments' by default,
    # which is wrong for this heatmap: its rows are pH levels.
    format_heatmap_axis(ax, title)
    ax.set_ylabel('pH', labelpad=15)

def create_upset_plot(df):
    """
    Creates an UpSet-style plot with statistical analysis including:
    - Compartment percentages per clade
    - Jaccard similarity indices between clades
    Shows both matching and mismatching predictions side by side

    Layout:
    - One row each for SAL, the AHL clades (bars stacked by clade) and CNP.
    - Each row has a 'Yes' panel ("Prediction matching pH") and a 'No' panel
      ("Prediction mismatched to pH").
    - A statistics row follows.
    - All bar panels share one y-axis limit so heights are comparable.

    NOTE(review): bar formatting, the membership matrix and the statistics row
    are delegated to ``format_bar_plot``, ``create_matrix_plot`` and
    ``add_statistics``, which are not defined in this cell. Confirm they exist
    before calling.

    Args:
        df: DataFrame with a 'clade' column and one column per entry of
            COMPARTMENTS holding 'Yes'/'No'

    Returns:
        matplotlib Figure
    """
    bar_axes = []        # every bar panel, so a common y-limit can be applied
    max_bar_height = 0   # tallest (stacked) bar across all panels

    def get_combinations_and_counts(sub_df, data_type='Yes'):
        """Count each combination of compartments equal to *data_type*, overall and per clade.

        Returns (combinations sorted by count descending, ties in first-seen
        order; total counts; per-clade counts).
        """
        counts = {}
        clade_counts = {}
        clades_present = sub_df['clade'].unique()

        for _, row in sub_df.iterrows():
            current_set = frozenset(comp for comp in COMPARTMENTS
                                    if pd.notna(row[comp]) and row[comp] == data_type)
            if not current_set:
                continue
            if current_set not in counts:
                counts[current_set] = 0
                clade_counts[current_set] = {clade: 0 for clade in clades_present}
            counts[current_set] += 1
            clade_counts[current_set][row['clade']] += 1

        # dicts keep insertion order and sorted() is stable
        all_combinations = sorted(counts, key=counts.get, reverse=True)
        return all_combinations, counts, clade_counts

    def create_subplot_pair(data, gs_section, title):
        """Create a pair of subplots for matching and mismatching predictions."""
        nonlocal max_bar_height
        for data_type, subplot_gs in zip(['Yes', 'No'], [gs_section[0], gs_section[1]]):
            combos, counts, clade_counts = get_combinations_and_counts(data, data_type)

            if not combos:
                continue

            nested_gs = subplot_gs.subgridspec(2, 1, height_ratios=[1.5, 1], hspace=0.05)
            ax_sets = fig.add_subplot(nested_gs[0])
            bar_axes.append(ax_sets)

            bar_width = 0.8
            x_positions = np.arange(len(combos))

            if title in ['CNP', 'SAL']:
                color = '#440154' if title == 'CNP' else '#fde725'
                heights = [counts[comb] for comb in combos]
                ax_sets.bar(x_positions, heights, color=color, width=bar_width)
                max_bar_height = max(max_bar_height, max(heights))
            else:
                bottom = np.zeros(len(combos))
                total_heights = np.zeros(len(combos))
                for clade in sorted(data['clade'].unique()):
                    heights = [clade_counts[comb][clade] for comb in combos]
                    ax_sets.bar(x_positions, heights, bottom=bottom,
                              label=clade, color=AHL_COLORS[clade], width=bar_width)
                    bottom += heights
                    total_heights += heights
                max_bar_height = max(max_bar_height, total_heights.max())

            # Format bar plot
            plot_title = f'{title} Prediction {"matching" if data_type == "Yes" else "mismatched to"} pH'
            format_bar_plot(ax_sets, plot_title, title)

            # Create matrix plot
            ax_matrix = fig.add_subplot(nested_gs[1])
            create_matrix_plot(ax_matrix, combos, x_positions, bar_width)

    # Separate dataframes for each group
    df_cnp = df[df['clade'] == 'CNP']
    df_sal = df[df['clade'] == 'SAL']
    df_ahl = df[df['clade'].isin(['AHLa', 'AHLb', 'Basal_AHL', 'Algal_AHL'])]

    # Create figure
    fig = plt.figure(figsize=(15, 28))
    gs = fig.add_gridspec(4, 1, height_ratios=[2, 2, 2, 2], hspace=0.4)
    gs_plots = [gs[i].subgridspec(1, 2, width_ratios=[1, 1], wspace=0.3) for i in range(3)]

    # Create plots
    create_subplot_pair(df_sal, gs_plots[0], 'SAL')
    create_subplot_pair(df_ahl, gs_plots[1], 'AHL Clades')
    create_subplot_pair(df_cnp, gs_plots[2], 'CNP')

    # Common y-limit (10% headroom) so panels are comparable; skipped when no
    # bars were drawn, which would otherwise request an invalid (0, 0) range.
    if max_bar_height > 0:
        for ax in bar_axes:
            ax.set_ylim(0, max_bar_height * 1.1)

    # Add statistics
    add_statistics(fig, gs[3], df_cnp, df_sal, df_ahl)

    # Add legend; fall back to the full name for any compartment without an abbreviation
    legend_text = "Legend: " + ", ".join(f"{COMPARTMENT_ABBREV.get(comp, comp)}={comp}"
                                       for comp in COMPARTMENTS)
    fig.text(0.1, 0.01, legend_text, fontsize=8)

    plt.subplots_adjust(bottom=0.15, right=0.85, left=0.15)
    return fig

def visualize_protein_localizations(df, min_samples=0):
    """
    Build and display each protein-localisation figure in turn.

    Three figures are produced, each preceded by a progress message: the pH
    heatmap, the pH-localisation small multiples, and the upset plot. The
    first two are followed by a one-line caption explaining the figure.

    Args:
        df: Input DataFrame with protein localization data
        min_samples: Minimum number of samples required for including a taxon;
            forwarded to ``pH_heatmap`` and ``create_small_multiples`` only
            (the upset plot is built from ``df`` alone)
    """
    # Each step: (progress banner, zero-arg figure builder, caption or None).
    pipeline = [
        (
            "\nCreating complementation plot...",
            lambda: pH_heatmap(df, min_samples=min_samples),
            'This plot shows what proportion of sequences in a particular taxa have a value at a particular pH',
        ),
        (
            "\nCreating pH-localisation plot...",
            lambda: create_small_multiples(df, min_samples=min_samples),
            'This plot shows what proportion of sequences in a particular taxa at a particular pH have a predicted localisation that would likely complement AtSAL1 activity',
        ),
        (
            "\nCreating upset plot...",
            lambda: create_upset_plot(df),
            None,
        ),
    ]

    for banner, build_figure, caption in pipeline:
        print(banner)
        build_figure().show()
        if caption is not None:
            print(caption)

In [None]:
def process_cellular_locations(df):
    """
    Add per-location Yes/No columns and an overall 'accuracy' column to ``df``.

    For every location named in a row's ``'localisation'`` (a ``'|'``-separated
    string), the row scores ``'Yes'`` if any of the prediction columns tied to
    that location holds a value that is not missing and not the string ``'0'``,
    otherwise ``'No'``. Special cases:

    * ``'Guard cell cytoplasm'`` is scored for rows localised to ``'Cytoplasm'``
      and checks only the ``'7_6'`` column.
    * ``'Plastid stroma/lumen dark'`` (pH 7.0-7.3) and ``'Plastid lumen light'``
      (pH 7.9-8.2) are both scored for rows localised to ``'Plastid'``.

    A location column is ``pd.NA`` for rows whose ``'localisation'`` is missing
    or does not mention that location.

    Args:
        df: DataFrame containing a ``'localisation'`` column and the prediction
            columns ``'7_0'``, ``'7_3'``, ``'7_6'``, ``'7_9'`` and ``'8_2'``.
            Missing prediction cells may be ``None``, ``NaN`` or ``pd.NA``.

    Returns:
        The same DataFrame object, modified in place, with one column per
        location plus an ``'accuracy'`` column: ``'Yes'`` (every scored
        location matched), ``'No'`` (none matched), ``'Partial'`` (some
        matched) or ``pd.NA`` (no location scored).
    """
    # Prediction columns (named after pH, e.g. '7_0' -> pH 7.0) consulted for
    # each location.
    loc_dict = {
        'Cytoplasm': ['7_0', '7_3'],
        'Nucleus': ['7_0', '7_3'],
        'Plastid stroma/lumen dark': ['7_0', '7_3'],
        'Plastid lumen light': ['7_9', '8_2'],
        'Mitochondria': ['7_9', '8_2'],  # This is for the mitochondrial matrix
        'Peroxisome': ['8_2'],
    }
    # Guard cell cytoplasm is scored from its own single column.
    guard_cell_columns = ['7_6']
    plastid_locations = ('Plastid stroma/lumen dark', 'Plastid lumen light')
    all_locations = list(loc_dict) + ['Guard cell cytoplasm']

    def has_prediction(row, columns):
        # Test missingness BEFORE comparing: a missing cell in a nullable
        # 'string' column is pd.NA, and `pd.NA != '0'` evaluates to pd.NA,
        # whose truth value raises TypeError inside `and` / `any()`.
        return any(not pd.isna(row[col]) and row[col] != '0' for col in columns)

    def check_location(row, location):
        # Which 'localisation' entry activates this column, and which
        # prediction columns decide Yes/No for it.
        if location == 'Guard cell cytoplasm':
            trigger, columns = 'Cytoplasm', guard_cell_columns
        elif location in plastid_locations:
            trigger, columns = 'Plastid', loc_dict[location]
        else:
            trigger, columns = location, loc_dict[location]

        localisation = row['localisation']
        if pd.isna(localisation) or trigger not in localisation.split('|'):
            return pd.NA
        return 'Yes' if has_prediction(row, columns) else 'No'

    for location in all_locations:
        df[location] = df.apply(check_location, axis=1, args=(location,))

    def determine_accuracy(row):
        # Only locations that were actually scored (non-NA) take part.
        values = [row[col] for col in all_locations if not pd.isna(row[col])]
        if not values:
            return pd.NA
        has_yes = 'Yes' in values
        has_no = 'No' in values
        if has_yes and has_no:
            return 'Partial'
        return 'Yes' if has_yes else 'No'

    df['accuracy'] = df.apply(determine_accuracy, axis=1)

    return df