In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MultipleLocator, FuncFormatter
import seaborn as sns
from scipy.ndimage import gaussian_filter1d
from scipy.optimize import linear_sum_assignment
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances, silhouette_score
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler,normalize
from sklearn.impute import KNNImputer
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score, pairwise_distances
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.manifold import TSNE
import umap
import igraph as ig
import leidenalg as la
import warnings
warnings.filterwarnings('ignore')

from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D 

pd.set_option("display.max_columns", None)  # None = no column limit, so df.head() shows every column

#### Opening the text file that describes the columns

In [None]:
# Read the plain-text description of the data columns and show it.
# An explicit encoding makes the read independent of the platform's locale default.
with open("../data/gc_peak2_all_colnames.txt", "r", encoding="utf-8") as f:
    description_colnames = f.read()

print(description_colnames)

#### Hard-coded variables:
- treatment_group list
- hex_colors list

In [None]:
## LIST CONTAINING THE TREATMENT GROUP NAMES
treatment_groups = [
        'BT474_mV_72hNHWD',
        'BT474_mV_72hSTC15',
        'BT474_mV_high_Untreated',
        'BT474_mV_low_Untreated',
        'BT474_mV_Untreated_Unsorted']

## ARBITRARILY CHOSEN COLORS FOR THE CLUSTERS (INTENDED TO HAVE HIGH VALUE AND CHROMA SO THEY STAND OUT ON A GREY BACKGROUND)
## 26 entries; a categorical variable with more categories than this needs extra colors.
## NOTE(review): a few entries (e.g. "#7f7f7f") are themselves grey and may blend into a grey background.
hex_colors = [
    "#0e67a7", "#ff7f0e", "#a0e468", "#d62728", "#9467bd",
    "#672417", "#e377c2", "#f5f523", "#28e0f5", "#3214a8",
    "#ca9d16", "#04a887", "#8c564b", "#17becf", "#bcbd22",
    "#2ca02c", "#1f77b4", "#ff9896", "#c5b0d5", "#98df8a",
    "#ffbb78", "#aec7e8", "#7f7f7f", "#c49c94", "#dbdb8d",
    "#9edae5"
]

##### Opening the data file
The dataframe is named `df`.

In [None]:
# Whitespace-delimited table, one row per methylation call.
# sep=r"\s+" is the documented equivalent of delim_whitespace=True, which is deprecated since pandas 2.2.
df = pd.read_csv("../data/gc_peak2_all.txt", sep=r"\s+")

print(f'The shape of the dataframe is of {df.shape}')

df.head()

In [None]:
# Average Pol II profile per cluster (bulk reference).
# sep=r"\s+" replaces the deprecated delim_whitespace=True.
df_bulk = pd.read_csv("../data/camille_average_profile_pol2_by_cluster.txt", sep=r"\s+")

# Report the shape of the frame just loaded (this previously printed df.shape by mistake).
print(f'The shape of the dataframe is of {df_bulk.shape}')

# True for a cluster when its 'cov' column holds a single distinct value.
is_constant = df_bulk.groupby('cluster')['cov'].nunique().eq(1)

print("cov constant per cluster:", is_constant.all())

df_bulk.head(50)

A short helper that recalculates the Pol II position manually, in case the 'motif_center' column contains wrong values.

In [None]:
# MANUALLY RECALCULATING POL2 POSITION AS THE MIDPOINT BETWEEN MOTIF_START AND MOTIF_END
def calculate_pol2_position(df):
    """Add a 'pol2_pos' column equal to the midpoint of 'motif_start' and 'motif_end'.

    The column is written onto ``df`` itself (the input frame is modified), and that same
    frame is returned, so the call works both as a statement and in an assignment.

    Raises
    ------
    ValueError
        If either 'motif_start' or 'motif_end' is missing.
    """
    needed = {'motif_start', 'motif_end'}
    if not needed.issubset(df.columns):
        raise ValueError("DataFrame must contain 'motif_start' and 'motif_end' columns.")

    midpoint = df['motif_start'].add(df['motif_end']).div(2)
    df['pol2_pos'] = midpoint
    return df

##### Preprocessing
The following functions are defined here:
1. 'preprocess_dataframe'
2. 'filter_reads_per_gene_middle_bin_name'
3. 'bin_then_matrix'
4. 'process_into_matrix'
5. 'handling_NaN'
6. 'preprocess_long_for_plot'
7. 'plot_reads_long'
8. 'get_genes_list'

In [None]:
## 1. PREPROCESS OF THE DATAFRAME FUNCTION

def preprocess_dataframe(
    df: pd.DataFrame,
    drop_columns: list = None,
) -> pd.DataFrame:
    """
    Clean the raw per-position dataframe and add derived identifier columns.

    Steps, in order:
    0. Drop exact duplicate rows, then create 'pol2_pos' as the midpoint of 'motif_start'
       and 'motif_end'. The 'motif_center' column is deliberately ignored in case it holds
       wrong values.
    1. Remove the user-listed ``drop_columns``; names that are absent are skipped.
    2. Convert 'chr' to numeric by stripping the "chr" prefix. Non-numeric names such as
       "chrX" become NaN. This works whether the column holds strings or is already numeric.
    3. Encode the 'pol2' label as a binary marker: "pol2" -> 1, "nonpol2" -> 0, any other
       label -> NaN. A column that is already numeric is left untouched, so the function
       can be re-run on its own output.
    4. Create a unique 'gene_tss' identifier "<gene>_<tss_pos>" when both columns exist.
    5. Drop duplicate rows again after the transformations.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame; must contain 'motif_start' and 'motif_end'. It is not modified.
    drop_columns : list, optional
        Columns to drop.

    Returns
    -------
    pd.DataFrame
        Cleaned and preprocessed copy.

    Raises
    ------
    ValueError
        If 'motif_start' or 'motif_end' is missing (raised by ``calculate_pol2_position``).
    """
    # Work on a copy so the caller's frame is never modified
    df_proc = df.copy().drop_duplicates()

    # 0. Create 'pol2_pos' as the midpoint between 'motif_start' and 'motif_end'
    df_proc = calculate_pol2_position(df_proc)

    # 1. Drop user-deemed non-useful columns (silently skipping names that do not exist)
    if drop_columns:
        df_proc = df_proc.drop(columns=[col for col in drop_columns if col in df_proc.columns])

    # 2. Convert 'chr' to numeric. astype(str) avoids an AttributeError from the .str
    #    accessor when the column is already numeric (e.g. on a second run).
    if "chr" in df_proc.columns:
        chr_names = df_proc["chr"].astype(str).str.replace("chr", "", regex=False)
        df_proc["chr"] = pd.to_numeric(chr_names, errors="coerce")

    # 3. Encode 'pol2' as 1/0 only if it still holds text labels. Mapping an already
    #    numeric column would turn every value into NaN.
    if 'pol2' in df_proc.columns and not pd.api.types.is_numeric_dtype(df_proc['pol2']):
        df_proc["pol2"] = df_proc["pol2"].map({"pol2": 1, "nonpol2": 0})

    # 4. Unique gene/TSS identifier
    if "gene" in df_proc.columns and "tss_pos" in df_proc.columns:
        df_proc["gene_tss"] = df_proc["gene"].astype(str) + "_" + df_proc["tss_pos"].astype(str)

    # 5. Drop duplicates again; dropping columns can create new duplicate rows
    df_proc = df_proc.drop_duplicates()

    print(f"Final shape: {df_proc.shape}")
    print(f"NUmber of gene_tss: {df_proc['gene_tss'].nunique() if 'gene_tss' in df_proc.columns else 'N/A'}")

    return df_proc


In [None]:
## 2. FUNCTION TO FILTER ONLY THE GENES THAT HAVE SUFFICIENT READS THAT OVERLAP ON THE SAME GENOMIC REGION, CENTERED ON THE MIDDLE
## THE INPUT DF MUST BE BINNED

def filter_reads_per_gene_middle_bin_name(
    df: pd.DataFrame,
    middle_bin_name: str =None,
    min_reads: int =50,
    min_bins: int =40,
    require_middle_bin: bool =True
):
    """
    For each gene, keep the reads that overlap a well-covered run of consecutive bins.

    Bin coverage within a 'gene_tss' group is the number of reads with a non-NaN value in
    that bin. A region is a maximal run of consecutive bins whose coverage is at least
    ``min_reads`` and that spans at least ``min_bins`` bins. Optionally only regions that
    contain ``middle_bin_name`` are kept. A read is kept when it has a value in at least one
    kept region. Genes without any kept region are dropped entirely.

    Parameters
    ----------
    df : pd.DataFrame
        Binned matrix: reads as rows, bins as columns, with a 'gene_tss' index level
        (grouping uses ``groupby(level='gene_tss')``). A 'readid' column, if present, is
        excluded from the bin columns.
    middle_bin_name : column label, optional
        Bin label that kept regions must contain. Bin labels produced by ``bin_then_matrix``
        are typically numeric, so this may be a number despite the ``str`` hint.
    min_reads : int
        Minimum number of reads (non-NaN values) a bin needs to count as covered.
    min_bins : int
        Minimum number of consecutive covered bins for a region.
    require_middle_bin : bool
        If True (and ``middle_bin_name`` is given), drop regions not containing that bin.

    Returns
    -------
    df_filtered : pd.DataFrame
        Kept reads, with the original index preserved, concatenated gene by gene in groupby
        order. An empty DataFrame is returned when no gene passes.
    filtered_regions : dict
        gene -> list of (start_idx, end_idx) inclusive positional bin intervals.

    Raises
    ------
    ValueError
        If ``middle_bin_name`` is given but is not one of the bin columns.
    """
    # Everything except a 'readid' column is treated as a bin column
    # (note: Index.difference returns the labels sorted).
    bin_cols = df.columns.difference(['readid']) if 'readid' in df.columns else df.columns

    middle_pos = None
    if middle_bin_name is not None:
        if middle_bin_name not in bin_cols:
            raise ValueError(f"Middle bin '{middle_bin_name}' not found in bin columns")
        middle_pos = bin_cols.get_loc(middle_bin_name)

    def covered_runs(coverage, min_reads, min_bins):
        """Inclusive (start, end) index pairs of maximal runs with coverage >= min_reads
        that are at least min_bins long."""
        # A trailing False sentinel closes a run that reaches the last bin.
        flags = [value >= min_reads for value in coverage] + [False]
        runs = []
        run_start = None
        for pos, covered in enumerate(flags):
            if covered:
                if run_start is None:
                    run_start = pos
            elif run_start is not None:
                if pos - run_start >= min_bins:
                    runs.append((run_start, pos - 1))
                run_start = None
        return runs

    kept_parts = []
    filtered_regions = {}

    for gene, gene_reads in df.groupby(level='gene_tss', observed=True):
        present = gene_reads[bin_cols].notna()
        regions = covered_runs(present.sum(axis=0), min_reads, min_bins)

        if require_middle_bin and middle_pos is not None:
            regions = [(s, e) for s, e in regions if s <= middle_pos <= e]
        if not regions:
            continue

        filtered_regions[gene] = regions

        # A read is kept if it has any value inside any kept region
        present_arr = present.to_numpy()
        keep = np.zeros(len(gene_reads), dtype=bool)
        for s, e in regions:
            keep |= present_arr[:, s:e + 1].any(axis=1)
        kept_parts.append(gene_reads[keep])

    if not kept_parts:
        return pd.DataFrame(), {}

    return pd.concat(kept_parts, ignore_index=False), filtered_regions


In [None]:
## 3. BINNING THE DATAFRAME AND TURNING IT INTO MATRIX FORM
# TO BE CALLED INSIDE A FUNCTION, NO USE ALONE

def bin_then_matrix(
        df : pd.DataFrame ,
        indexes : list = ['gene_tss','readid','group','cluster'],  # never mutated, so the list default is safe
        bin_size : int =50,
):
    """
    Bin the methylation calls by position and pivot into a read x bin matrix.

    Parameters
    ----------
    df : pd.DataFrame
        Long frame with at least 'C_pos', 'meth' and every column named in ``indexes``.
        It is NOT modified (previously a 'bin' column was added to the caller's frame).
    indexes : list
        Columns that become the row index of the matrix.
        Default: ['gene_tss', 'readid', 'group', 'cluster'].
    bin_size : int
        Bin width in bp (default 50). A position p falls in bin ``(p // bin_size) * bin_size``,
        so a bin is labelled by its lower edge.

    Returns
    -------
    pd.DataFrame
        If every index column is present: a matrix indexed by ``indexes``, with one column
        per bin and the mean 'meth' per (row, bin) as values (NaN where a read has no call).
        Otherwise: the (binned) long frame, returned unpivoted.
    """
    # assign() returns a new frame, so the caller's df (possibly a slice) is left untouched
    if 'C_pos' in df.columns:
        df = df.assign(bin=(df['C_pos'] // bin_size) * bin_size)

    # Pivot only when all requested index columns exist. The former condition
    # `'readid' and 'group' in df.columns` only ever tested for 'group'.
    if all(col in df.columns for col in indexes):
        df = df.pivot_table(index=indexes, columns='bin', values='meth', aggfunc='mean')

    return df


In [None]:
## 4. PROCESSING -- MARKING THE METHYLATION OF READS BY BINNED REGIONS -- GIVING OUT A DATAFRAME

def process_into_matrix(
        df: pd.DataFrame,
        gene: object=None,          # a 'gene_tss' id, only if we want to process a single gene
        bin_size: int = 50,         # bin size, by default 50bp
):
    """
    Turn the preprocessed long frame into a binned methylation matrix.
    Works either on the whole dataframe (all genes) or on a single gene.

    1. Bin the GpC positions ('C_pos'); bin size defaults to 50 bp.
    2. Pivot into a matrix with the mean 'meth' per bin as values. The rows are indexed by
       'readid', 'group', 'cluster' for a single gene, plus a leading 'gene_tss' level when
       the whole frame is processed.

    Parameters
    ----------
    df : pd.DataFrame
        Preprocessed long dataframe (needs 'gene_tss' when ``gene`` is given).
    gene : object, optional
        'gene_tss' value to isolate. By default (None) the whole dataframe is processed.
    bin_size : int
        Bin width in bp (default 50).

    Returns
    -------
    pd.DataFrame
        Matrix that still contains missing values (not yet standardised), ready for NaN
        handling, PCA and clustering.
    """
    if gene is not None:
        # Explicit copy: the binning step adds a column, and the per-gene selection must be
        # an independent frame rather than a view/slice of df.
        df_gene = df[df['gene_tss'] == gene].copy()
        matrix = bin_then_matrix(df_gene, indexes=['readid', 'group', 'cluster'], bin_size=bin_size)
    else:
        matrix = bin_then_matrix(df, indexes=['gene_tss', 'readid', 'group', 'cluster'], bin_size=bin_size)

    print(f"Shape of the dataframe : {matrix.shape}")

    return matrix


In [None]:
## 5. TO HANDLE THE NAN VALUES ONCE IN MATRIX FORM

def handling_NaN(
        df: pd.DataFrame,            # must be binned and in matrix form
        nan_threshold: float = 0.7,  # bins with too many missing values are dropped
        nan_method: object = None
):
    """
    Drop sparsely covered bins, then drop or impute the remaining missing values.

    Parameters
    ----------
    df : pd.DataFrame
        Binned methylation matrix (reads as rows, bins as columns).
    nan_threshold : float, default 0.7
        A bin (column) is kept only if its fraction of non-NaN values is strictly greater
        than this value. Only applied when ``nan_method`` is 'drop' or 'impute'.
    nan_method : {None, 'drop', 'impute'}, default None
        None     -> return ``df`` unchanged (NaNs left in place).
        'drop'   -> after the bin filter, drop every read (row) that still has a NaN.
        'impute' -> after the bin filter, fill NaNs with ``KNNImputer(n_neighbors=5)``.

    Returns
    -------
    pd.DataFrame
        The matrix with or without missing values, depending on ``nan_method``.

    Raises
    ------
    ValueError
        If ``nan_method`` is not None, 'drop' or 'impute'. Such values used to be silently
        ignored, leaving the NaNs in place.
    """
    # Leave the missing values in place
    if nan_method is None:
        return df
    if nan_method not in ('drop', 'impute'):
        raise ValueError(f"nan_method must be None, 'drop' or 'impute', got {nan_method!r}")

    # Shared first step: drop bins whose non-NaN fraction is not above the threshold
    coverage_mask = df.notna().mean(axis=0) > nan_threshold
    dropped_columns = df.columns[~coverage_mask].tolist()
    df = df.loc[:, coverage_mask]

    print("Dropped (NaN coverage):", dropped_columns)

    if nan_method == 'drop':
        # Keep only reads with a value in every remaining bin
        df = df.dropna(how='any')
    else:
        imputer = KNNImputer(n_neighbors=5)
        imputed = imputer.fit_transform(df)  # numpy array, same shape (no all-NaN bins remain)
        # Build a new frame instead of writing back into a .loc slice
        df = pd.DataFrame(imputed, index=df.index, columns=df.columns)

    return df


In [None]:
## 6. FUNCTION TO PREPARE THE DATAFRAME IN LONG FORMAT FOR PLOTTING READS
def preprocess_long_for_plot(df, include_locus_cluster: bool = False):
    """
    Prepare a dataframe in long format for plotting reads.
    Each row is one methylation call (position 'C_start') of one read.

    Parameters
    ----------
    df : pd.DataFrame
        Input with at least ['gene_tss', 'group', 'cluster', 'pol2_pos', 'readid',
        'C_start', 'meth'] as columns or as named index levels (optionally 'locus_cluster').
    include_locus_cluster : bool, optional
        If True, keep the 'locus_cluster' column (for plotting after clustering), when present.
        If False, drop it.

    Returns
    -------
    pd.DataFrame
        The kept columns plus 'read_start' / 'read_end' (min / max 'C_start' of each read),
        with a fresh RangeIndex.
    """
    # Move named index levels (e.g. 'gene_tss', 'readid' after a set_index) into columns.
    # A bare reset_index(drop=True) silently discarded them.
    named_levels = [name for name in df.index.names if name is not None and name not in df.columns]
    if named_levels:
        df = df.reset_index(level=named_levels)
    df = df.reset_index(drop=True)

    # Keep only the columns the plotting code needs, in their original order
    base_cols = ["gene_tss", "group", "readid", "cluster", "C_start", "meth", 'pol2_pos']
    keep_cols = [c for c in df.columns if c in base_cols]

    if include_locus_cluster and "locus_cluster" in df.columns:
        keep_cols.append("locus_cluster")

    df = df[keep_cols]

    # Per-read extent, merged back onto every row of that read (scalar columns, no lists)
    span = (
        df.groupby("readid")["C_start"]
        .agg(["min", "max"])
        .rename(columns={"min": "read_start", "max": "read_end"})
    )
    df = df.merge(span, on="readid", how="left")

    return df


In [None]:
## 7. FUNCTION TO PLOT THE READS IN LONG FORMAT
def plot_reads_long(
    df,
    filters=None,
    facet_by=None,
    color_by=None,
    hex_colors=None,
    max_facets=14
):
    """
    Plot per-read methylation (long format, one row per position per read), centered on Pol2.
    Facet heights scale with the number of reads per facet.

    Each read is drawn as a horizontal line from read_start to read_end. The line is colored
    by ``color_by`` if given, otherwise black. Each position is drawn as a dot: white where
    meth == 1, black otherwise. A red vertical line marks Pol2 (x = 0).

    Parameters
    ----------
    df : pd.DataFrame
        Long-format dataframe (one row per position per read).
        Must have 'readid', 'C_start', 'meth', 'read_start', 'read_end' and 'pol2_pos'.
    filters : dict, optional
        Column:value filters, e.g. {"gene_tss": "BRCA1", "group": "control"}; list values use isin.
    facet_by : str or list, optional
        Column(s) to facet subplots by (e.g. "group", ["gene_tss","group"]).
    color_by : str, optional
        Column to color read spans by (e.g. "cluster").
    hex_colors : list, optional
        Colors for the sorted categories of ``color_by`` (defaults to the 'tab20' colormap).
    max_facets : int
        Prevents generating too many subplots at once.

    Returns
    -------
    None
        The figure is created with pyplot and displayed by the notebook backend.

    Raises
    ------
    ValueError
        If a filter or ``color_by`` column is missing, no rows remain after filtering, there
        are more than ``max_facets`` facets, ``hex_colors`` is too short, or 'pol2_pos' is missing.
    """
    df = df.copy()

    # --- Filtering ---
    if filters:
        for col, val in filters.items():
            if col not in df.columns:
                raise ValueError(f"Column '{col}' not found in dataframe.")
            if isinstance(val, list):
                df = df[df[col].isin(val)]
            else:
                df = df[df[col] == val]

    if df.empty:
        raise ValueError("No reads left after filtering!")

    # Checked once, before any figure exists (it used to be checked per facet after plt.figure())
    if "pol2_pos" not in df.columns:
        raise ValueError("Column 'pol2_pos' not found for centering!")

    # --- Faceting ---
    if facet_by is None:
        facet_values = [("All", df)]
    else:
        if isinstance(facet_by, str):
            facet_by = [facet_by]
        facet_values = list(df.groupby(facet_by))
        if len(facet_values) > max_facets:
            raise ValueError(f"Too many facets ({len(facet_values)}). Max allowed: {max_facets}")

    # --- Color mapping ---
    color_map = None
    if color_by:
        if color_by not in df.columns:
            raise ValueError(f"Column '{color_by}' not found in dataframe.")
        categories = sorted(df[color_by].dropna().unique())
        if hex_colors is None:
            # pyplot.get_cmap works on all Matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9
            cmap = plt.get_cmap("tab20", max(len(categories), 1))
            hex_colors = [cmap(i) for i in range(len(categories))]
        if len(hex_colors) < len(categories):
            raise ValueError(f"Not enough colors for {len(categories)} categories.")
        color_map = dict(zip(categories, hex_colors))

    # --- Figure setup with gridspec: one row per facet, height proportional to its read count ---
    n_facets = len(facet_values)
    heights = [max(1, subdf["readid"].nunique() * 0.15) for _, subdf in facet_values]
    total_height = sum(heights) + 2  # padding for titles / legend
    fig = plt.figure(figsize=(30, total_height))
    gs = gridspec.GridSpec(n_facets, 1, height_ratios=heights)

    # --- Plot each facet ---
    for i, (facet_key, subdf) in enumerate(facet_values):
        ax = fig.add_subplot(gs[i, 0])
        ax.set_facecolor("#919191")  # grey background

        # Shift positions so that Pol2 sits at x = 0
        subdf = subdf.assign(
            C_start_shifted=subdf["C_start"] - subdf["pol2_pos"],
            read_start_shifted=subdf["read_start"] - subdf["pol2_pos"],
            read_end_shifted=subdf["read_end"] - subdf["pol2_pos"],
        )

        # Split the facet into per-read frames once. Re-filtering the whole facet for every
        # read was quadratic in the number of rows.
        rows_by_read = {rid: grp for rid, grp in subdf.groupby("readid", sort=False)}

        # Order reads by color_by category if requested
        if color_by and color_by in subdf.columns:
            grouped_reads = subdf.groupby(color_by)["readid"].unique().to_dict()
            read_order = [(cat, rid) for cat, rids in grouped_reads.items() for rid in rids]
        else:
            read_order = [(None, rid) for rid in subdf["readid"].unique()]

        # Plot each read on its own row
        for idx, (cat_value, readid) in enumerate(read_order):
            sub = rows_by_read.get(readid)
            if sub is None:  # e.g. a NaN readid, which groupby drops
                continue
            span_color = color_map.get(cat_value, "black") if color_by else "black"

            ax.hlines(
                idx,
                xmin=sub["read_start_shifted"].iloc[0],
                xmax=sub["read_end_shifted"].iloc[0],
                color=span_color,
                linewidth=1.6
            )

            # Position dots: one plot call per color instead of one per position
            methylated = (sub["meth"] == 1).to_numpy()
            x_pos = sub["C_start_shifted"].to_numpy()
            for mask, dot_color in ((methylated, "white"), (~methylated, "black")):
                if mask.any():
                    ax.plot(x_pos[mask], np.full(int(mask.sum()), idx), "o", color=dot_color, markersize=4)

        # Vertical line at Pol2
        ax.axvline(x=0, color="#C80028", linestyle="-", linewidth=2, label="Pol2 position")

        # Labels, limits, formatting
        ax.set_xlabel("Position relative to Pol2 (bp)", fontsize=12)
        ax.set_ylabel("Read IDs", fontsize=12)
        x_min = subdf["read_start_shifted"].min() - 100
        x_max = subdf["read_end_shifted"].max() + 100
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(-2, len(read_order) + 2)

        # Facet title
        if facet_by:
            if isinstance(facet_key, tuple):
                title = ", ".join([f"{col}={val}" for col, val in zip(facet_by, facet_key)])
            else:
                title = f"{facet_by[0]}={facet_key}"
        else:
            title = "All Reads"
        ax.set_title(f"Read-level methylation centered on Pol2 ({title})", fontsize=16)
        ax.grid(True)
        ax.invert_yaxis()

    # --- Legend ---
    handles = [
        mlines.Line2D([], [], color="black", marker="o", linestyle="None", markersize=6, label="Unmethylated (0)"),
        mlines.Line2D([], [], color="black", marker="o", markerfacecolor="white", linestyle="None", markersize=6, label="Methylated (1)"),
        mlines.Line2D([], [], color="#C80028", linestyle="-", linewidth=2, label="Pol2 position")
    ]
    if color_map:
        for cat, col in color_map.items():
            handles.append(mlines.Line2D([], [], color=col, linewidth=2, label=f"{color_by}={cat}"))

    fig.legend(handles=handles, loc="upper right", fontsize=9, frameon=True)
    fig.tight_layout()
    fig.subplots_adjust(top=0.93, bottom=0.05)


In [None]:
## 8. SMALL FUNCTION TO GET THE LIST OF ALL THE GENES NAMES OF A DATAFRAME AND THE LIST OF ALL THE SUBDATAFRAME OF THESE GENES

def get_genes_list(df):
    """Split a frame that has a 'gene_tss' index level into per-gene pieces.

    Returns
    -------
    (list, list)
        The 'gene_tss' values in groupby (sorted) order, and the matching sub-DataFrames.
    """
    per_gene = list(df.groupby(df.index.get_level_values("gene_tss")))
    gene_list_names = [gene for gene, _ in per_gene]
    gene_list_df = [piece for _, piece in per_gene]
    return gene_list_names, gene_list_df

##### Functions written for the clustering:
- 'clustering_final'
- (unused) a parameter sweep aimed at obtaining an exact number of clusters
- 'run_pipelines_on_genes' (to draw violin/scatter plots of the metrics)
- 'plot_violin_scatter'
- 'plot_compare_pipelines_grid'
- 'silhouettes_multi'
- 'alternate_clustering' (together with 'filter_by_missingness' and 'masked_pearson_correlationmatrix')
- 'plot_umap'

In [None]:
## 3. IMPROVED CLUSTERING ALGORITHM, FOLLOWING THE STEPS: PCA --> ADAPTIVE KNN GRAPH (WEIGHTED) --> LEIDEN
def _clustering_quality_metrics(X, clusters, part, g, metric):
    """Internal quality scores for a Leiden clustering (None where a score is undefined).

    Silhouette, Calinski-Harabasz and Davies-Bouldin need 2 <= n_clusters <= n_samples - 1.
    Outside that range (e.g. Leiden found a single community) they are reported as None
    instead of raising.
    """
    metrics = {}
    n_samples = X.shape[0]
    n_clusters = len(np.unique(clusters))
    scorable = 2 <= n_clusters <= n_samples - 1

    # Silhouette (the old fallback re-raised for a single cluster, crashing the whole call)
    sil = None
    if scorable:
        try:
            sil = silhouette_score(X, clusters, metric=metric)
        except Exception:
            # Fall back to an explicit distance matrix if the metric cannot be passed directly
            D = pairwise_distances(X, metric=metric)
            sil = silhouette_score(D, clusters, metric='precomputed')
    metrics['silhouette'] = sil

    # Calinski-Harabasz and Davies-Bouldin
    try:
        metrics['calinski_harabasz'] = calinski_harabasz_score(X, clusters) if scorable else None
    except Exception:
        metrics['calinski_harabasz'] = None

    try:
        metrics['davies_bouldin'] = davies_bouldin_score(X, clusters) if scorable else None
    except Exception:
        metrics['davies_bouldin'] = None

    # Leiden objective value
    try:
        metrics['leiden_quality'] = float(part.quality())
    except Exception:
        metrics['leiden_quality'] = None

    # Weighted modularity
    try:
        metrics['modularity'] = g.modularity(clusters.tolist(), weights=g.es['weight'])
    except Exception:
        metrics['modularity'] = None

    # Cluster sizes
    unique, counts = np.unique(clusters, return_counts=True)
    metrics['cluster_sizes'] = dict(zip(unique.tolist(), counts.tolist()))

    return metrics


def clustering_final(
    df,
    n_neighbors=15,
    nan_threshold : float = 0.7,  # bins with too many NaN are dropped
    nan_method : str = 'drop',    # 'drop' removes reads that still contain NaN; 'impute' uses kNN imputation
    scaling : bool = False,       # whether to standardise the features
    pca_or_not = True,            # whether to run PCA
    n_pcs= None,                  # None: keep 95% of the variance; int: number of PCs to keep
    metric='cosine',              # 'cosine' or 'euclidean'
    transform='none',             # 'none', 'logit', or 'arcsine'
    kernel_type='laplacian',      # 'laplacian' or 'gaussian'
    leiden_resolution=1.0,
    seed=42
):
    """
    Cluster reads with: NaN handling -> optional transform/scaling -> optional PCA
    -> adaptive weighted kNN graph -> Leiden.

    Parameters
    ----------
    df : pandas.DataFrame
        Binned methylation matrix (reads x bins), possibly with NaNs.
    n_neighbors : int, default=15
        Neighbors per node for the kNN graph (capped at the number of samples).
    nan_threshold : float, default=0.7
        Passed to ``handling_NaN``: minimum non-NaN fraction for a bin to be kept.
    nan_method : str, default='drop'
        Passed to ``handling_NaN`` ('drop' or 'impute'). Any NaN remaining afterwards raises.
    scaling : bool, default=False
        Standardise features (zero mean, unit variance) before PCA.
    pca_or_not : bool, default=True
        Run PCA before building the graph.
    n_pcs : None, int or float, default=None
        None keeps the components explaining 95% of the variance. A positive int keeps that
        many components, capped at min(n_samples, n_features) - 1. Any other value is passed
        to ``PCA(n_components=...)`` unchanged.
    metric : str, default='cosine'
        Distance metric for the nearest-neighbor search and the silhouette score. With
        'cosine' and no PCA, rows are L2-normalised first.
    transform : {'none', 'logit', 'arcsine'}, default='none'
        Element-wise transform applied before scaling/PCA. 'logit' clips to [1e-3, 1-1e-3];
        'arcsine' computes arcsin(sqrt(x)) after clipping to [0, 1].
    kernel_type : {'laplacian', 'gaussian'}, default='laplacian'
        'laplacian': w = exp(-d / tau_i), where tau_i is node i's median neighbor distance.
        'gaussian': w = exp(-d^2 / (sigma_i * sigma_j)), where sigma is the distance to the
        k-th neighbor.
    leiden_resolution : float, default=1.0
        Resolution parameter of ``RBConfigurationVertexPartition``.
    seed : int, default=42
        Random seed for PCA and Leiden.

    Returns
    -------
    df : pandas.DataFrame
        The matrix after NaN handling (rows align with ``clusters``).
    X_pca : np.ndarray
        Embedding used to build the graph.
    part : leidenalg partition
        Leiden partition object.
    clusters : np.ndarray
        Cluster label per row.
    metrics : dict
        Keys 'silhouette', 'calinski_harabasz', 'davies_bouldin' (None when undefined),
        'leiden_quality', 'modularity' and 'cluster_sizes'.

    Raises
    ------
    ValueError
        If NaNs remain after ``handling_NaN``, or ``transform`` / ``kernel_type`` is unknown.
    """
    # --- 1) Handle NaNs, then convert to NumPy ---
    df = handling_NaN(df, nan_threshold, nan_method)
    X = df.to_numpy(dtype=float)

    if np.isnan(X).any():
        raise ValueError("NaNs remain after handling_NaN; impute explicitly before PCA.")

    # --- 2) Optional transformation ---
    if transform == 'logit':
        Xc = np.clip(X, 1e-3, 1 - 1e-3)
        Xc = np.log(Xc / (1 - Xc))
    elif transform == 'arcsine':
        Xc = np.arcsin(np.sqrt(np.clip(X, 0.0, 1.0)))
    elif transform == 'none':
        Xc = X
    else:
        raise ValueError("transform must be 'none', 'logit', or 'arcsine'")

    # --- 3) Optional standardisation ---
    if scaling:
        Xc = StandardScaler(with_mean=True, with_std=True).fit_transform(Xc)

    # --- 4) Optional PCA ---
    if pca_or_not:
        if n_pcs is None:
            # 95% explained variance. Previously this PCA object was immediately replaced by
            # PCA(n_components=None), which silently kept every component.
            n_components = 0.95
        elif isinstance(n_pcs, int) and n_pcs > 0:
            n_components = min(n_pcs, min(Xc.shape) - 1)
        else:
            n_components = n_pcs  # e.g. a float variance fraction, passed through unchanged
        pca = PCA(n_components=n_components, svd_solver='auto', random_state=seed)
        X_pca = pca.fit_transform(np.nan_to_num(Xc, nan=0.0))
        print("Number of components chosen:", pca.n_components_)
    else:
        X_pca = Xc

    if metric == 'cosine' and not pca_or_not:
        X_pca = normalize(X_pca, norm='l2', axis=1)
    N = X_pca.shape[0]

    # --- 5) kNN graph (NearestNeighbors cannot return more neighbors than samples) ---
    k = min(n_neighbors, N)
    nn = NearestNeighbors(n_neighbors=k, metric=metric)
    nn.fit(X_pca)
    dist, idx = nn.kneighbors(X_pca, return_distance=True)  # dist, idx: (N, k)

    # Detect whether each point is returned as its own first neighbor
    self_included = np.all(idx[:, 0] == np.arange(N))
    eps = 1e-12

    # --- 6) Adaptive kernel weights ---
    if kernel_type == 'laplacian':
        # Per-node scale tau_i: median neighbor distance (excluding self if present)
        dist_for_scale = dist[:, 1:] if self_included else dist
        tau = np.median(dist_for_scale, axis=1) + 1e-6  # (N,)
        sim = np.exp(-dist / tau[:, None])  # (N, k), depends on i's scale only

    elif kernel_type == 'gaussian':
        # Local scale sigma_i = distance to the k-th neighbor (last column)
        sigma = dist[:, -1] + eps  # (N,)
        sigma_j = sigma[idx]       # (N, k)
        sigma_i = sigma[:, None]   # (N, 1)
        sim = np.exp(- (dist ** 2) / (sigma_i * sigma_j + eps))

    else:
        raise ValueError("kernel_type must be 'laplacian' or 'gaussian'")

    # --- 7) Undirected weighted graph: union of kNN edges, keeping the max weight per pair ---
    edges = {}
    start_col = 1 if self_included else 0
    for i in range(N):
        for j, w in zip(idx[i, start_col:], sim[i, start_col:]):
            if i == j or w <= 0:
                continue
            a, b = (i, j) if i < j else (j, i)
            edges[(a, b)] = max(edges.get((a, b), 0.0), float(w))

    e_list = list(edges.keys())
    w_list = [edges[e] for e in e_list]

    g = ig.Graph(n=N, edges=e_list, directed=False)
    g.es["weight"] = w_list

    # --- 8) Leiden ---
    part = la.find_partition(
        g,
        la.RBConfigurationVertexPartition,
        weights=g.es['weight'],
        resolution_parameter=leiden_resolution,
        seed=seed
    )
    clusters = np.array(part.membership)

    # --- 9) Quality metrics ---
    metrics = _clustering_quality_metrics(X_pca, clusters, part, g, metric)

    return df, X_pca, part, clusters, metrics


In [None]:
# ## SWEEP TO FIND THE OPTIMAL PARAMETERS TO GET 6 CLUSTERS

# # 1) Prepare embedding once (transform -> optional scale -> PCA)
# def prepare_embedding(
#     df,
#     nan_threshold=0.7,
#     nan_method='drop',         # your handling_NaN must implement this
#     transform='arcsine',       # 'none' | 'logit' | 'arcsine'
#     scaling=False,
#     pca_or_not=True,
#     n_pcs=None,                # None => 95% variance
#     seed=42
# ):
#     """Clean, transform, optionally scale, and perform PCA on data."""

#     df_used = handling_NaN(df, nan_threshold, nan_method)
#     X = df_used.to_numpy(float)

#     if np.isnan(X).any():
#         raise ValueError("NaNs remain after handling_NaN; impute before PCA.")

#     # --- Transform ---
#     if transform == 'logit':
#         X = np.clip(X, 1e-3, 1 - 1e-3)
#         X = np.log(X / (1 - X))
#     elif transform == 'arcsine':
#         X = np.arcsin(np.sqrt(np.clip(X, 0.0, 1.0)))
#     elif transform == 'none':
#         pass
#     else:
#         raise ValueError("transform must be 'none', 'logit', or 'arcsine'")

#     # --- Optional scaling ---
#     if scaling:
#         X = StandardScaler(with_mean=True, with_std=True).fit_transform(X)

#     # --- PCA ---
#     if pca_or_not:
#         if n_pcs is None:
#             pca = PCA(n_components=0.95, random_state=seed)
#         else:
#             n_pcs = min(int(n_pcs), min(X.shape) - 1)
#             pca = PCA(n_components=n_pcs, random_state=seed)

#         X_pca = pca.fit_transform(X)
#         n_comp = pca.n_components_
#     else:
#         X_pca = X
#         n_comp = X_pca.shape[1]

#     return df_used, X_pca, n_comp


# # 2) Compute kNN once at max_k, slice for smaller k
# def kneighbors_upto_k(X, max_k, metric='euclidean'):
#     """Compute neighbors up to max_k and reuse slices for smaller k."""
#     nn = NearestNeighbors(n_neighbors=int(max_k), metric=metric)
#     nn.fit(X)
#     dist, idx = nn.kneighbors(X, return_distance=True)  # (N, max_k)

#     self_included = np.all(idx[:, 0] == np.arange(X.shape[0]))
#     return dist, idx, self_included


# # 3) Build graph from a slice of neighbors (vectorized)
# def graph_from_neighbors(dist_slice, idx_slice, self_included, kernel_type='laplacian'):
#     """Construct undirected weighted graph from neighbor distances."""
#     N, k = dist_slice.shape

#     # --- Adaptive scales ---
#     dist_for_scale = dist_slice[:, 1:] if self_included else dist_slice
#     tau = np.median(dist_for_scale, axis=1) + 1e-6

#     # --- Similarity (Laplacian kernel) ---
#     sim = np.exp(-dist_slice / tau[:, None])

#     # --- Vectorized edge construction: union kNN with max weight ---
#     I = np.repeat(np.arange(N), k)
#     J = idx_slice.ravel()
#     W = sim.ravel()

#     # Drop self edges
#     mask = I != J
#     I, J, W = I[mask], J[mask], W[mask]

#     # Undirected: keep i < j, combine duplicates with max weight
#     a = np.minimum(I, J)
#     b = np.maximum(I, J)

#     edf = pd.DataFrame({'a': a, 'b': b, 'w': W})
#     edf = edf.groupby(['a', 'b'], as_index=False)['w'].max()

#     g = ig.Graph(n=N, edges=list(zip(edf['a'], edf['b'])), directed=False)
#     g.es['weight'] = edf['w'].to_numpy()

#     return g


# # 4) Run Leiden and compute metrics (optionally sample for silhouette)
# def run_leiden_and_metrics(
#     g,
#     X_embed,
#     resolution,
#     metric='euclidean',
#     seed=42,
#     silhouette_sample=None
# ):
#     """Run Leiden clustering and compute evaluation metrics."""
#     part = la.find_partition(
#         g,
#         la.RBConfigurationVertexPartition,
#         weights=g.es['weight'],
#         resolution_parameter=float(resolution),
#         seed=int(seed)
#     )
#     labels = np.array(part.membership)

#     out = {
#         'part': part,
#         'clusters': labels,
#         'n_clusters': int(len(np.unique(labels))),
#         'leiden_quality': float(part.quality()),
#         'modularity': float(g.modularity(labels.tolist(), weights=g.es['weight']))
#     }

#     # --- Only compute metrics if >1 cluster ---
#     if out['n_clusters'] > 1:
#         X_eval, y_eval = X_embed, labels

#         if silhouette_sample is not None and X_embed.shape[0] > silhouette_sample:
#             rng = np.random.default_rng(seed)
#             idx = rng.choice(X_embed.shape[0], size=int(silhouette_sample), replace=False)
#             X_eval = X_embed[idx]
#             y_eval = labels[idx]

#         try:
#             out['silhouette'] = float(silhouette_score(X_eval, y_eval, metric=metric))
#         except Exception:
#             out['silhouette'] = None

#         try:
#             out['calinski_harabasz'] = float(calinski_harabasz_score(X_eval, y_eval))
#         except Exception:
#             out['calinski_harabasz'] = None

#         try:
#             out['davies_bouldin'] = float(davies_bouldin_score(X_eval, y_eval))
#         except Exception:
#             out['davies_bouldin'] = None

#     else:
#         out['silhouette'] = None
#         out['calinski_harabasz'] = None
#         out['davies_bouldin'] = None

#     return out


# # 5) Adaptive resolution search on a fixed graph (fast)
# def search_resolution_on_graph(
#     g,
#     X_embed,
#     target=6,
#     res_init=1.0,
#     res_min=0.02,
#     res_max=10.0,
#     up_factor=1.5,
#     max_iter=12,
#     metric='euclidean',
#     seed=42,
#     silhouette_sample=None
# ):
#     """Search for Leiden resolution yielding target cluster count."""
#     history = []
#     res = float(res_init)

#     for _ in range(int(max_iter)):
#         out = run_leiden_and_metrics(
#             g, X_embed, resolution=res,
#             metric=metric, seed=seed, silhouette_sample=silhouette_sample
#         )
#         out['resolution'] = res
#         history.append(out)

#         c = out['n_clusters']
#         if c == int(target):
#             break
#         if c < target:
#             res = min(res * up_factor, res_max)
#         else:
#             res = max(res / up_factor, res_min)

#     return history


# # 6) Full sweep over k (reusing PCA and kNN)
# def sweep_k_for_target_fast(
#     df,
#     k_list,
#     target=6,
#     transform='arcsine',       # or 'logit'
#     metric='euclidean',
#     scaling=False,
#     pca_or_not=True,
#     n_pcs=None,
#     kernel_type='laplacian',
#     nan_threshold=0.7,
#     nan_method='drop',
#     res_init=1.0,
#     res_min=0.02,
#     res_max=10.0,
#     up_factor=1.5,
#     max_iter=12,
#     silhouette_sample=None,    # e.g., 2000 to speed up silhouette on large N
#     seed=42
# ):
#     """Sweep over multiple k values to find best clustering."""
#     # --- Prepare embedding once ---
#     df_used, X_pca, n_comp = prepare_embedding(
#         df, nan_threshold, nan_method, transform,
#         scaling, pca_or_not, n_pcs, seed
#     )

#     # --- Compute neighbors once at max k ---
#     max_k = int(max(k_list))
#     dist_max, idx_max, self_included = kneighbors_upto_k(
#         X_pca, max_k=max_k, metric=metric
#     )

#     all_rows = []
#     best = None
#     best_score = (-np.inf, np.inf)  # (silhouette, |n_clusters - target|)

#     # --- Loop over k ---
#     for k in k_list:
#         k = int(k)
#         dist = dist_max[:, :k]
#         idx = idx_max[:, :k]

#         # Build graph for this k
#         g = graph_from_neighbors(dist, idx, self_included, kernel_type=kernel_type)

#         # Sweep resolution on this fixed graph
#         hist = search_resolution_on_graph(
#             g, X_pca, target=target, res_init=res_init, res_min=res_min,
#             res_max=res_max, up_factor=up_factor, max_iter=max_iter,
#             metric=metric, seed=seed, silhouette_sample=silhouette_sample
#         )

#         # Collect summary
#         for h in hist:
#             all_rows.append({
#                 'k': k,
#                 'resolution': h['resolution'],
#                 'n_clusters': h['n_clusters'],
#                 'silhouette': h['silhouette'],
#                 'calinski_harabasz': h['calinski_harabasz'],
#                 'davies_bouldin': h['davies_bouldin'],
#                 'leiden_quality': h['leiden_quality'],
#                 'modularity': h['modularity']
#             })

#         # Select best for this k (prefer exact target + high silhouette)
#         best_k_run = sorted(
#             hist,
#             key=lambda h: (
#                 -(h['n_clusters'] == target),
#                 h['silhouette'] or -1e9,
#                 -abs(h['n_clusters'] - target)
#             )
#         )[-1]

#         # Track global best
#         cur_sil = best_k_run['silhouette'] or -np.inf
#         cur_delta = abs(best_k_run['n_clusters'] - target)

#         if (cur_sil > best_score[0]) or (cur_sil == best_score[0] and cur_delta < best_score[1]):
#             best_score = (cur_sil, cur_delta)
#             best = {
#                 'k': k,
#                 'resolution': best_k_run['resolution'],
#                 'n_clusters': best_k_run['n_clusters'],
#                 'clusters': best_k_run['clusters'],
#                 'part': best_k_run['part'],
#                 'graph': g,
#                 'X_pca': X_pca,
#                 'df_used': df_used
#             }

#     summary = pd.DataFrame(all_rows)
#     return summary, best


In [None]:
def run_pipelines_on_genes(gene_names, gene_df, pipeline_configs):
    """
    Run multiple clustering pipelines on multiple gene dataframes.

    Parameters
    ----------
    gene_names : sequence of str
        Gene identifiers; ``gene_names[i]`` names the dataframe ``gene_df[i]``.
    gene_df : sequence of pandas.DataFrame
        Dataframes associated with these genes (same length and order as ``gene_names``).
    pipeline_configs : list of dict
        Each dictionary may contain {'name': str} (defaults to 'pipeline'); every
        other key is forwarded as a keyword argument to clustering_final().

    Returns
    -------
    results_df : pandas.DataFrame
        Tidy DataFrame with one row per gene x pipeline holding clustering metrics,
        or an 'error' entry when that run raised.
    outputs : dict
        Nested dictionary for detailed outputs:
        outputs[gene_id][pipeline_name] = {
            'X_pca', 'part', 'clusters', 'metrics', 'graph'
        }
        or {'error': message} when the run failed.

    Raises
    ------
    ValueError
        If ``gene_names`` and ``gene_df`` do not have the same length (zip() would
        otherwise silently drop the unmatched genes).
    """
    gene_names = list(gene_names)
    gene_df = list(gene_df)
    if len(gene_names) != len(gene_df):
        raise ValueError(
            f"gene_names ({len(gene_names)}) and gene_df ({len(gene_df)}) "
            "must have the same length."
        )

    rows = []
    outputs = {}
    # gene -> dataframe; if a gene name is repeated, its last dataframe wins.
    genes_dict = dict(zip(gene_names, gene_df))

    for gene_id, df in genes_dict.items():
        outputs[gene_id] = {}

        for cfg in pipeline_configs:
            name = cfg.get('name', 'pipeline')
            kwargs = {k: v for k, v in cfg.items() if k != 'name'}

            try:
                _, X_pca, part, clusters, metrics = clustering_final(df, **kwargs)

                # Leiden partitions carry the graph they were computed on; guard
                # against objects without it (or with it set to None).
                graph = getattr(part, 'graph', None)

                rows.append({
                    'gene': gene_id,
                    'pipeline': name,
                    'silhouette': metrics.get('silhouette'),
                    'calinski_harabasz': metrics.get('calinski_harabasz'),
                    'davies_bouldin': metrics.get('davies_bouldin'),
                    'leiden_quality': metrics.get('leiden_quality'),
                    'modularity': metrics.get('modularity'),
                    'n_clusters': len(np.unique(clusters)),
                    # NOTE(review): this is the input row count; if clustering_final
                    # drops rows it may differ from the number of graph nodes.
                    'n_nodes': len(df),
                    'n_edges': len(graph.es) if graph is not None else None
                })

                outputs[gene_id][name] = {
                    'X_pca': X_pca,
                    'part': part,
                    'clusters': clusters,
                    'metrics': metrics,
                    'graph': graph
                }

            except Exception as e:
                # Best-effort sweep: record the failure and keep going.
                rows.append({
                    'gene': gene_id,
                    'pipeline': name,
                    'error': str(e)
                })
                outputs[gene_id][name] = {'error': str(e)}

    results_df = pd.DataFrame(rows)
    return results_df, outputs


In [None]:
def plot_violin_scatter(
    results_df,
    metrics=('silhouette', 'calinski_harabasz', 'davies_bouldin',
             'leiden_quality', 'modularity', 'n_clusters')
):
    """
    Violin + scatter (strip) plots comparing pipelines for each metric across genes.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Output from run_pipelines_on_genes(); must have 'gene' and 'pipeline' columns.
    metrics : tuple of str
        Metrics to visualize; names absent from results_df are ignored.

    Returns
    -------
    figs : dict
        Dictionary of matplotlib Figures, keyed by metric name. Empty when none
        of the requested metrics is present or every value is missing.
    """
    value_vars = [m for m in metrics if m in results_df.columns]
    if not value_vars:
        return {}

    # Melt to long format: one row per (gene, pipeline, metric)
    df_long = results_df.melt(
        id_vars=['gene', 'pipeline'],
        value_vars=value_vars,
        var_name='metric',
        value_name='value'
    )

    # Metric columns are object dtype when some runs stored None; make them numeric
    # so seaborn treats the y axis as continuous, then drop missing/failed results.
    df_long['value'] = pd.to_numeric(df_long['value'], errors='coerce')
    df_long = df_long.dropna(subset=['value'])

    # Plot one figure per metric, drawing explicitly on its own Axes
    figs = {}
    for metric_name, sub in df_long.groupby('metric'):
        fig, ax = plt.subplots(figsize=(7, 5))
        sns.violinplot(
            data=sub,
            x='pipeline',
            y='value',
            inner=None,
            cut=0,
            ax=ax
        )
        sns.stripplot(
            data=sub,
            x='pipeline',
            y='value',
            color='k',
            size=4,
            alpha=0.6,
            jitter=0.2,
            ax=ax
        )
        ax.set_title(f'Comparison by {metric_name}')
        ax.set_xlabel('')
        ax.set_ylabel(metric_name)
        ax.grid(axis='y', linestyle='--')
        fig.tight_layout()

        figs[metric_name] = fig

    return figs


In [None]:
def plot_compare_pipelines_grid(
    results_df: pd.DataFrame,
    pipelines_order=('pca_euclidean', 'cosine_no_pca'),
    metrics=('silhouette', 'calinski_harabasz', 'davies_bouldin',
             'leiden_quality', 'modularity', 'n_clusters'),
    kind='violin',                # 'violin' or 'box'
    show_points=True,
    connect_pairs=True,           # draw line per gene connecting the two pipelines
    figsize=None,
    point_kwargs=None,
    violin_kwargs=None,
    box_kwargs=None
):
    """
    Make ONE figure with one subplot per metric; within each subplot,
    show pipelines side-by-side.

    Parameters
    ----------
    results_df : pd.DataFrame
        Must contain columns ['gene', 'pipeline', <metrics...>].
    pipelines_order : tuple
        Order of pipelines to display.
    metrics : tuple
        Which metrics to plot (names absent from results_df are ignored).
    kind : {'violin', 'box'}
        Plot type for distributions.
    show_points : bool
        Whether to overlay individual gene points.
    connect_pairs : bool
        Whether to connect paired genes across pipelines (only when exactly two
        pipelines are shown). Lines join the category centres, not the jittered points.
    figsize : tuple or None
        Figure size; computed automatically if None.
    point_kwargs : dict or None
        Passed to seaborn.stripplot (default: color='k', size=4, alpha=0.6).
    violin_kwargs, box_kwargs : dict or None
        Passed to seaborn.violinplot / seaborn.boxplot.

    Returns
    -------
    fig, axes : matplotlib Figure and Axes objects

    Raises
    ------
    ValueError
        If kind is invalid, no requested metric exists, or no row matches pipelines_order.
    """
    # Validate before any figure is created so a bad call leaves no empty figure behind.
    if kind not in ('violin', 'box'):
        raise ValueError("kind must be 'violin' or 'box'")

    pipelines_order = list(pipelines_order)

    # --- Filter metrics and pipelines ---
    metrics = [m for m in metrics if m in results_df.columns]
    if not metrics:
        raise ValueError("No requested metrics found in results_df.")

    df = results_df.loc[results_df['pipeline'].isin(pipelines_order)].copy()
    if df.empty:
        raise ValueError("results_df has no rows for the requested pipelines_order.")

    # Ensure pipelines are ordered consistently
    df['pipeline'] = pd.Categorical(df['pipeline'],
                                    categories=pipelines_order,
                                    ordered=True)

    # --- Melt to long format; metric columns may be object dtype (None entries) ---
    df_long = df.melt(
        id_vars=['gene', 'pipeline'],
        value_vars=metrics,
        var_name='metric',
        value_name='value'
    )
    df_long['value'] = pd.to_numeric(df_long['value'], errors='coerce')
    df_long = df_long.dropna(subset=['value'])

    # --- Set up figure ---
    n_cols = len(metrics)
    if figsize is None:
        figsize = (3.2 * n_cols, 4.2)  # scale width by number of metrics

    fig, axes = plt.subplots(1, n_cols, figsize=figsize, sharey=False)
    if n_cols == 1:
        axes = [axes]

    # --- Default style dictionaries ---
    # stripplot takes `size` (marker diameter) rather than scatter's `s`.
    if point_kwargs is None:
        point_kwargs = dict(color='k', size=4, alpha=0.6)
    if violin_kwargs is None:
        violin_kwargs = dict(inner=None, cut=0, linewidth=0)
    if box_kwargs is None:
        box_kwargs = dict(fliersize=0, linewidth=1.2)
    strip_kwargs = {'dodge': False, 'jitter': 0.15, **point_kwargs}

    if kind == 'violin':
        dist_plot, dist_kwargs = sns.violinplot, violin_kwargs
    else:
        dist_plot, dist_kwargs = sns.boxplot, box_kwargs

    # --- Color palette ---
    palette = sns.color_palette('Set2', n_colors=len(pipelines_order))
    pipeline_colors = dict(zip(pipelines_order, palette))

    # --- Plot each metric ---
    for ax, metric_name in zip(axes, metrics):
        sub = df_long[df_long['metric'] == metric_name]

        if sub.empty:
            # Every value of this metric is missing for the selected pipelines.
            ax.text(0.5, 0.5, 'no data', ha='center', va='center',
                    transform=ax.transAxes)
        else:
            # Main distribution plot
            dist_plot(
                data=sub,
                x='pipeline', y='value',
                order=pipelines_order,
                palette=pipeline_colors,
                ax=ax, **dist_kwargs
            )

            # Overlay points
            if show_points:
                sns.stripplot(
                    data=sub,
                    x='pipeline', y='value',
                    order=pipelines_order,
                    ax=ax, **strip_kwargs
                )

            # Connect paired points (same gene across pipelines)
            if connect_pairs and len(pipelines_order) == 2:
                p0, p1 = pipelines_order

                # Pivot to wide format (gene x pipeline). reindex guarantees both
                # columns exist even when one pipeline has no value for this metric.
                wide = (
                    sub.assign(pipeline=sub['pipeline'].astype(object))
                    .pivot_table(index='gene', columns='pipeline',
                                 values='value', aggfunc='first')
                    .reindex(columns=[p0, p1])
                    .dropna(how='any')
                )

                # Draw connecting lines between category positions 0 and 1
                for y0, y1 in zip(wide[p0], wide[p1]):
                    ax.plot([0, 1], [y0, y1],
                            color='gray', alpha=0.35, linewidth=1)

        # Axis formatting (after plotting, since seaborn sets its own x label)
        ax.set_title(metric_name.replace('_', ' ').title())
        ax.set_xlabel('')
        ax.grid(axis='y', linestyle='--', alpha=0.25)

    fig.tight_layout()
    return fig, axes


In [None]:
def silhouettes_multi(
    X_space,
    clusters,
    metric_main='cosine',
    extra_metrics=('correlation', 'euclidean')
):
    """
    Compute silhouette scores for multiple distance metrics.

    Args:
        X_space (array-like): Feature matrix (n_samples x n_features).
        clusters (array-like): Cluster labels, one per row of X_space.
        metric_main (str): Primary metric to evaluate (default: 'cosine').
        extra_metrics (tuple): Additional metrics to test (default: ('correlation', 'euclidean')).

    Returns:
        dict: {'silhouette_<metric>': score} for metric_main followed by each extra
        metric; a metric listed more than once is computed once. When a metric is
        not accepted directly by silhouette_score, a precomputed distance matrix is
        used instead. Every score is None when the labelling is degenerate (fewer
        than 2 clusters, or as many clusters as samples), because the silhouette is
        undefined there.
    """
    labels = np.asarray(clusters)
    n_samples = labels.shape[0]
    n_labels = len(np.unique(labels))

    # dict.fromkeys keeps first-seen order while dropping duplicate metric names.
    metric_names = list(dict.fromkeys((metric_main, *extra_metrics)))

    # silhouette_score requires 2 <= n_labels <= n_samples - 1; short-circuit instead
    # of failing twice (the fallback would also build an O(N^2) distance matrix).
    if not (2 <= n_labels <= n_samples - 1):
        return {f'silhouette_{m}': None for m in metric_names}

    results = {}
    for m in metric_names:
        try:
            score = silhouette_score(X_space, labels, metric=m)
        except Exception:
            D = pairwise_distances(X_space, metric=m)
            score = silhouette_score(D, labels, metric='precomputed')
        results[f'silhouette_{m}'] = score

    return results


In [None]:
## 3.1 ALTERNATE PIPELINE FOR CLUSTERING

def filter_by_missingness(
    df: pd.DataFrame,
    min_bin_non_nan_frac: float = 0.2,  # keep bins (columns) seen in >= 20% of reads
    min_read_non_nan_frac: float = 0.5  # keep reads (rows) with >= 50% bins observed
) -> pd.DataFrame:
    """
    Drop sparsely observed bins (columns) and reads (rows).

    Both observed fractions are measured on the unfiltered frame. Frames with
    fewer than 2 rows or 2 columns are returned as an unfiltered copy.
    """
    n_reads, n_bins = df.shape
    if min(n_reads, n_bins) < 2:
        return df.copy()

    observed = df.notna()
    bin_mask = observed.mean(axis=0).ge(float(min_bin_non_nan_frac))
    read_mask = observed.mean(axis=1).ge(float(min_read_non_nan_frac))

    return df.loc[read_mask, bin_mask].copy()


def masked_pearson_correlation_matrix(
    df: pd.DataFrame,
    min_overlap: int = 30
) -> pd.DataFrame:
    """
    Read-by-read Pearson correlation using only bins observed in both reads.

    Each pair is correlated over its shared non-NaN columns; pairs with fewer than
    min_overlap shared columns get NaN. The result is an (N x N) DataFrame indexed
    (on both axes) by df's row labels.
    """
    # DataFrame.corr correlates columns, so turn reads into columns first.
    reads_as_columns = df.transpose()
    return reads_as_columns.corr(method='pearson', min_periods=int(min_overlap))


def overlap_matrix(df: pd.DataFrame) -> np.ndarray:
    """
    Count the shared non-NaN bins for every pair of reads.

    Parameters
    ----------
    df : pd.DataFrame
        Rows = reads, columns = bins; NaN marks an unobserved bin.

    Returns
    -------
    np.ndarray of shape (N, N), integer dtype
        O[i, j] = number of columns observed in both row i and row j; the diagonal
        holds each read's own number of observed bins.
    """
    # Cast the observation mask to integers BEFORE the product: a bool @ bool
    # matmul stays boolean (logical OR of ANDs), which collapses every count to
    # True/False and breaks any `O >= min_overlap` comparison downstream.
    M = df.notna().to_numpy(dtype=np.int64)  # shape (N, D)
    return M @ M.T


def build_masked_corr_knn_graph(
    df: pd.DataFrame,
    k: int = 15,
    min_overlap: int = 30,
    positive_only: bool = True,
    shrink_c: float = 30.0,  # overlap-based shrink: w *= n / (n + shrink_c)
    mutual: bool = True
):
    """
    Build a weighted kNN graph from masked Pearson correlations.

    - df: rows = reads, columns = bins, values = methylation fractions (NaN allowed).
    - k: maximum number of neighbours kept per read (highest weights first).
    - min_overlap: minimum number of shared non-NaN bins a pair needs to get a
      non-zero weight; also used as min_periods for the correlation.
    - positive_only: clip negative correlations to 0. Note that the kNN step below
      keeps only strictly positive weights anyway, so negative correlations never
      become edges even when this is False.
    - shrink_c: if not None and > 0, multiply each weight by O / (O + shrink_c),
      where O is the pair's shared-bin count, to down-weight low-overlap pairs.
    - mutual: if True, keep edge (i, j) only when each read is in the other's
      top-k list; otherwise keep the union of both top-k lists (max weight).

    Returns:
        igraph.Graph (undirected, edge attribute "weight"),
        weight list (in edge order),
        neighbor index list per node (top-k lists before symmetrization).

    Raises:
        ValueError: if df has fewer than 2 rows.

    NOTE(review): the min_overlap mask and the shrink factor assume overlap_matrix()
    returns integer counts (a boolean matrix would make `O >= min_overlap` False for
    any min_overlap > 1) — confirm its dtype.
    """
    N = df.shape[0]
    if N < 2:
        raise ValueError("Need at least 2 reads to build a graph.")

    # 1) Pairwise-complete correlation and overlap counts
    R_df = masked_pearson_correlation_matrix(df, min_overlap=min_overlap)  # (N x N)
    O = overlap_matrix(df)  # (N x N), shared bin counts

    # 2) Convert to weights
    # R may contain NaN (e.g. pairs below min_overlap); NaN entries propagate through
    # the steps below and are finally excluded by the `> 0` filter in the kNN step.
    R = R_df.to_numpy()
    mask_sufficient = (O >= int(min_overlap))
    W = np.where(mask_sufficient, R, 0.0)

    # Keep only nonnegative similarities if desired
    if positive_only:
        W = np.maximum(W, 0.0)

    # Optional shrink by overlap size (softly downweight low-overlap edges)
    if shrink_c is not None and shrink_c > 0:
        shrink = O / (O + float(shrink_c))
        W = W * shrink

    # Zero diagonal so a read is never its own neighbour
    np.fill_diagonal(W, 0.0)

    # 3) Build kNN (top-k by weight, skipping zeros)
    neighbors = []
    weights = []

    for i in range(N):
        row = W[i].copy()
        idx_sorted = np.argsort(row)[::-1]  # sort by weight descending
        idx_sorted = idx_sorted[row[idx_sorted] > 0]  # keep strictly positive weights (drops zeros, negatives and NaN)
        topk = idx_sorted[:k]
        neighbors.append(topk.tolist())
        weights.append(row[topk].tolist())

    # 4) Symmetrize (mutual kNN) and build edge list
    # edges maps an (a, b) pair with a < b to its weight; dict insertion order
    # fixes the igraph edge order below.
    edges = {}

    if mutual:
        neighbor_sets = [set(ns) for ns in neighbors]
        for i in range(N):
            for j in neighbors[i]:
                if i in neighbor_sets[j]:
                    a, b = (i, j) if i < j else (j, i)
                    wij = W[i, j]
                    wji = W[j, i]
                    w = max(wij, wji) # W is built from symmetric R and O, so wij == wji and max vs. mean gives the same value
                    if w > 0:
                        edges[(a, b)] = max(edges.get((a, b), 0.0), w)
    else:
        for i in range(N):
            for j, w in zip(neighbors[i], weights[i]):
                a, b = (i, j) if i < j else (j, i)
                edges[(a, b)] = max(edges.get((a, b), 0.0), w)

    e_list = list(edges.keys())
    w_list = [edges[e] for e in e_list]

    g = ig.Graph(n=N, edges=e_list, directed=False)
    g.es["weight"] = w_list

    return g, w_list, neighbors


def alternate_clustering(
    df: pd.DataFrame,
    apply_filter: bool = True,
    min_bin_non_nan_frac: float = 0.2,
    min_read_non_nan_frac: float = 0.5,
    transform: str | None = None,  # None | 'arcsine' | 'logit'
    min_overlap: int = 30,
    k: int = 15,
    positive_only: bool = True,
    shrink_c: float = 30.0,
    mutual: bool = True,
    leiden_resolution: float = 1.0,
    seed: int = 42
):
    """
    Cluster reads without imputation: masked Pearson correlation ->
    weighted (mutual) kNN graph -> Leiden (RBConfiguration) partition.

    Args:
        df: rows = reads, columns = bins, values in [0,1] with NaNs allowed.
        apply_filter: drop sparse bins/reads via filter_by_missingness first.
        transform: None, 'arcsine' (arcsin(sqrt(p))) or 'logit' (p clipped to
            [1e-3, 1 - 1e-3]); NaNs are preserved by both transforms.
        Remaining keyword arguments are forwarded to build_masked_corr_knn_graph
        and to leidenalg.find_partition.

    Returns:
        clusters (np.ndarray of int, one label per row of df_used),
        graph (igraph.Graph),
        df_used (pd.DataFrame after optional filtering/transform)

    Raises:
        ValueError: fewer than 2 reads or bins remain, or transform is unknown.
    """
    # 1) Optional missingness filter
    if not apply_filter:
        df_used = df.copy()
    else:
        df_used = filter_by_missingness(
            df,
            min_bin_non_nan_frac=min_bin_non_nan_frac,
            min_read_non_nan_frac=min_read_non_nan_frac
        )

    n_reads, n_bins = df_used.shape
    if n_reads < 2 or n_bins < 2:
        raise ValueError("Not enough reads or bins after filtering.")

    # 2) Optional variance-stabilizing transform for proportions
    if transform is not None:
        if transform not in ('arcsine', 'logit'):
            raise ValueError("transform must be None, 'arcsine', or 'logit'")
        values = df_used.to_numpy(float)
        if transform == 'arcsine':
            values = np.arcsin(np.sqrt(np.clip(values, 0.0, 1.0)))
        else:
            clipped = np.clip(values, 1e-3, 1 - 1e-3)
            values = np.log(clipped / (1 - clipped))
        df_used = pd.DataFrame(values, index=df_used.index, columns=df_used.columns)

    # 3) Graph from masked correlation (weight list / neighbour lists unused here)
    graph, _, _ = build_masked_corr_knn_graph(
        df_used,
        k=k,
        min_overlap=min_overlap,
        positive_only=positive_only,
        shrink_c=shrink_c,
        mutual=mutual
    )

    # 4) Leiden clustering on the weighted graph
    partition = la.find_partition(
        graph,
        la.RBConfigurationVertexPartition,
        weights=graph.es["weight"],
        resolution_parameter=float(leiden_resolution),
        seed=int(seed)
    )

    return np.array(partition.membership, dtype=int), graph, df_used


In [None]:
def plot_umap(
    X,
    clusters,
    n_neighbors=15,
    min_dist=0.1,
    metric='euclidean',      # for Pipeline 2, use 'euclidean'
    transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
    n_pcs=None,              # set if X are raw features; None if X are already PCs
    standardize=False,       # True if using raw features without PCA
    seed=42,
    palette=None,            # optional list/array or matplotlib colormap name
    title="UMAP embedding",
    gene = 'Gene'
):
    """
    Embed X with UMAP and draw a 2-D scatter coloured by cluster label.

    Parameters
    ----------
    X : array-like (n_samples, n_features)
        Features or PCA scores; converted to float.
    clusters : array-like (n_samples,)
        Cluster label per row of X.
    n_neighbors, min_dist, metric, seed
        Passed to umap.UMAP.
    transform : None | 'logit' | 'arcsine'
        Optional transform for raw proportions ('logit' clips to [1e-3, 1 - 1e-3]).
    n_pcs : int or None
        If 0 < n_pcs < n_features, run PCA to n_pcs components before UMAP.
    standardize : bool
        Z-score features (StandardScaler) before the optional PCA.
    palette : None, str or sequence
        None -> 'tab20'; str -> matplotlib colormap name; sequence -> colours,
        assigned to the sorted unique labels and cycled if too short.
    title, gene : str
        The plot title is f'{gene} : {title}'.

    Returns
    -------
    embedding : np.ndarray
        UMAP coordinates (columns 0 and 1 are plotted).
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If clusters and X have different lengths, transform is unknown, or
        metric='jaccard' is used on data that are not all 0/1.
    """
    X_in = np.asarray(X, dtype=float)

    if len(clusters) != X_in.shape[0]:
        raise ValueError(
            f"clusters has {len(clusters)} labels but X has {X_in.shape[0]} rows."
        )

    # Optional transform for proportions (only if X are raw fractions)
    if transform is not None:
        if transform == 'logit':
            Xc = np.clip(X_in, 1e-3, 1 - 1e-3)
            X_in = np.log(Xc / (1 - Xc))
        elif transform == 'arcsine':
            X_in = np.arcsin(np.sqrt(np.clip(X_in, 0.0, 1.0)))
        else:
            raise ValueError("transform must be None, 'logit', or 'arcsine'")

    # Optional standardization (useful if running UMAP on raw features)
    if standardize:
        scaler = StandardScaler(with_mean=True, with_std=True)
        X_in = scaler.fit_transform(X_in)

    # Optional PCA (skip if X are already PCA scores)
    if n_pcs is not None and 0 < n_pcs < X_in.shape[1]:
        pca = PCA(n_components=n_pcs, random_state=seed)
        X_umap = pca.fit_transform(X_in)
    else:
        X_umap = X_in

    # Jaccard is only meaningful for binary data: refuse anything that is not 0/1.
    if metric == 'jaccard' and not np.isin(X_umap, (0, 1)).all():
        raise ValueError(
            "Jaccard metric requires binary data; "
            "use euclidean/cosine/correlation for continuous features."
        )

    # UMAP embedding
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=metric,
        random_state=seed
    )
    embedding = reducer.fit_transform(X_umap)

    # Colors: one per sorted unique label, cycling through the palette
    unique_clusters = np.unique(clusters)
    if palette is None or isinstance(palette, str):
        cmap = plt.get_cmap('tab20' if palette is None else palette)
        color_map = {c: cmap(i % cmap.N) for i, c in enumerate(unique_clusters)}
    else:
        # palette is a list/array
        color_map = {c: palette[i % len(palette)] for i, c in enumerate(unique_clusters)}

    colors = [color_map[c] for c in clusters]

    # Plot on an explicit Axes
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.scatter(
        embedding[:, 0],
        embedding[:, 1],
        c=colors,
        alpha=0.85,
    )
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    ax.set_title(f'{gene} : {title}')
    ax.grid()

    # Legend with one marker per cluster
    handles = [
        plt.Line2D(
            [0], [0],
            marker='o',
            color='w',
            label=str(c),
            markerfacecolor=color_map[c],
            markersize=8
        )
        for c in unique_clusters
    ]
    ax.legend(
        handles=handles,
        title="Cluster",
        bbox_to_anchor=(1.02, 1.0),
        loc='upper left',
        frameon=False
    )
    fig.tight_layout()

    return embedding, fig


#### Annex functions for gathering clustering information: dictionary manipulation and start/end/center coordinates
1. 'dict_id_cluster_color'
2. 'dict_to_df'
3. 'merge'
4. 'start_end_center'

In [None]:
## 1. DICTIONNARY TO KEEP TRACK OF THE ASSOCIATION READID, CLUSTER AND COLOR
def dict_id_cluster_color(
        df: pd.DataFrame,
        clusters : list,
        hex_colors: list,
):
    """
    Build a dictionary mapping each readid to its cluster id and the colour of that cluster.

    df: the binned dataframe that was clustered; its index must contain a level named
        'readid', and its rows must be in the same order as `clusters`.
    clusters: the cluster label of each row of df (output of the clustering step).
    hex_colors: the hand-picked colour list (so we do not rely on colormaps). Colours
        are assigned to the sorted unique cluster ids in order, cycling if there are
        more clusters than colours.

    Returns a dict {readid: {"cluster": cluster_id, "color": hex_color}}. If a readid
    appears more than once, its last row wins.

    Raises ValueError if len(clusters) != len(df) (zip would otherwise silently
    mis-align/truncate), or if hex_colors is empty while there are clusters.
    """
    read_ids = df.index.get_level_values('readid')
    if len(clusters) != len(read_ids):
        raise ValueError(
            f"clusters has {len(clusters)} labels but df has {len(read_ids)} rows."
        )

    # Colour per cluster: i-th sorted unique cluster id -> i-th colour (cycled)
    unique_clusters = np.unique(clusters)
    if len(unique_clusters) > 0 and len(hex_colors) == 0:
        raise ValueError("hex_colors must contain at least one colour.")
    cluster_colors = {
        cluster_id: hex_colors[i % len(hex_colors)]
        for i, cluster_id in enumerate(unique_clusters)
    }

    # readid -> {cluster id, colour}
    read_dict = {
        read_id: {
            "cluster": cluster_id,
            "color": cluster_colors[cluster_id]
        }
        for read_id, cluster_id in zip(read_ids, clusters)
    }
    return read_dict

In [None]:
## 2. TRANSFORMING THE DICTIONNARY INTO A DATAFRAME
def dict_to_df(
        read_dict: dict
):
    """
    Turn the read dictionary into a flat dataframe so it can be merged with the original data.

    read_dict: the dictionary produced by dict_id_cluster_color
               ({readid: {"cluster": ..., "color": ...}}).

    Returns a dataframe with columns 'readid', 'locus_cluster', 'locus_cluster_color'.
    """
    column_names = {"cluster": "locus_cluster", "color": "locus_cluster_color"}
    return (
        pd.DataFrame.from_dict(read_dict, orient='index')
        .rename_axis('readid')
        .reset_index()
        .rename(columns=column_names)
    )

In [None]:
## 3. MERGING THE DICTIONNARY DATAFRAME WITH THE ORIGINAL DATAFRAME (TO GET THE CLUSTER AND COLOR INFORMATION FOR EACH READ)
def merge(
        df,
        df_dict: pd.DataFrame  
):
    """
    Attach the locus cluster id and colour to every read of df.

    df: read-level dataframe; 'readid' is taken from its columns, or else from its
        index (a named 'readid' index level, or an unnamed index renamed to 'readid').
    df_dict: output of dict_to_df (columns 'readid', 'locus_cluster', 'locus_cluster_color').

    Returns the inner merge on 'readid' (reads without a cluster assignment are
    dropped, as are rows whose 'locus_cluster' is missing).
    """
    if "readid" in df.columns:
        left = df.copy()
    else:
        left = df.reset_index().rename(columns={"index": "readid"})

    merged = left.merge(df_dict, on="readid", how="inner")
    return merged.dropna(subset=['locus_cluster'])

In [None]:
## 4. FUNCTION TO GET THE START, END AND CENTER COORDINATES OF A GENE
def start_end_center(
        df: pd.DataFrame
):
    """
    Return (start, end, center_coord) for one gene's reads.

    start        : smallest 'read_start'
    end          : largest 'read_end'
    center_coord : first distinct value of 'pol2_pos' (other values, if any, are ignored)

    Raises ValueError if any of the three columns is missing.
    """
    needed = ('read_start', 'read_end', 'pol2_pos')
    if not all(col in df.columns for col in needed):
        raise ValueError("DataFrame must contain 'read_start', 'read_end', and 'pol2_pos' columns.")

    return (
        df['read_start'].min(),
        df['read_end'].max(),
        df['pol2_pos'].unique()[0],
    )

#### Analysis functions: heatmaps and average methylation plots
1. `summary`
2. (unused) `plot_avg_methylation_profile`
3. (unused, commented out) `plot_avg_methylation_profile_bulk`
4. `compute_bulk_centroids`
5. `compute_gene_centroids`
6. `plot_centroids_with_shading`

In [None]:
## 1. TO COMPARE THE LOCUS-SPECIFIC CLUSTERS WITH THE BULK CLUSTERS AND THE GROUPS

def _locus_cluster_percentages(
        df_summary: pd.DataFrame,
        key: str,
        total_reads: int
) -> pd.DataFrame:
    """Pivot `key` x locus_cluster of read counts, expressed as % of `total_reads`."""
    # `key` may be an index level (e.g. of a MultiIndex) or a regular column
    if key in df_summary.index.names:
        grouped = df_summary[['locus_cluster']].groupby(df_summary.index.get_level_values(key))
    else:
        grouped = df_summary[[key, 'locus_cluster']].groupby(key)

    percentages = (
        grouped['locus_cluster']
        .value_counts()
        .div(total_reads).mul(100)
        .reset_index(name="percentage")
    )
    return percentages.pivot(
        index=key, columns="locus_cluster", values="percentage"
    ).fillna(0)


def summary(
        df: pd.DataFrame,
        clusters: list,
        gene='Gene'
):
    """
    Show two heatmaps comparing the locus-specific clusters with the groups and with
    the bulk clusters.

    Every cell is a percentage of *all* rows of `df` (not of its heatmap row).

    df       : reads dataframe; 'group' and 'cluster' may each be an index level or a column
    clusters : locus-cluster label for every row of `df`, in the same order
    gene     : name used in the figure title
    """
    df_summary = df.copy()
    df_summary['locus_cluster'] = clusters

    total_reads = len(df_summary)  # denominator for the global percentages

    fig, axes = plt.subplots(1, 2, figsize=(16, 6), constrained_layout=True)

    panels = [
        ('group', "Group", "Reads per cluster in each group (% of all reads)"),
        ('cluster', "Bulk Cluster", "Reads per cluster in each bulk cluster (% of all reads)"),
    ]
    for ax, (key, ylabel, title) in zip(axes, panels):
        pivot = _locus_cluster_percentages(df_summary, key, total_reads)
        sns.heatmap(pivot, annot=True, fmt=".2f", cmap="viridis", ax=ax)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("Cluster")
        ax.set_title(title)

    # Figure-level title: plt.title() would overwrite the title of the last panel
    fig.suptitle(f'{gene} : Heatmaps of repartitions')
    plt.show()


In [None]:
### (unused) 2. FUNCTION TO PLOT THE AVERAGE METHYLATION PROFILES PER CLUSTER (WITH THE COLORS ATTRIBUTED TO EACH CLUSTER)
def plot_avg_methylation_profile(
    df: pd.DataFrame,
    df_dict: pd.DataFrame,  # dataframe made out of the read_dict
    start: int,
    end: int,
    center_coord: int,
    read_dict: dict, # {read_id: {'cluster': cluster_id, 'color': hex_color}}
    gene= 'Gene'
):
    """
    Plot average methylation profiles per cluster using cluster-specific colors.

    Draws one subplot per locus cluster, stacked vertically and sorted by cluster id.
    Each subplot's height is proportional to the share of reads in that cluster. The
    mean profile, lightly Gaussian-smoothed, is drawn as a white area on the
    cluster's color.

    Parameters
    ----------
    df : pd.DataFrame
        Each row = read, columns = genomic positions (bins), values = mean
        methylation (0-1). Must have a 'readid' index level. Bin positions are
        presumably relative to Pol2: the Pol2 line is drawn at x=0 and tick
        labels show x + center_coord.
    df_dict : pd.DataFrame
        Table with ['readid', 'locus_cluster', 'locus_cluster_color'] (see dict_to_df).
    start, end : int
        Not used for plotting (the x-range spans the data bins); kept for
        backward compatibility.
    center_coord : int
        Absolute coordinate of the Pol2 position, added to the bottom tick labels.
    read_dict : dict
        Not used (colors are read from df_dict); kept for backward compatibility.
    gene : str
        Name shown in the figure title.

    Returns
    -------
    position_cols : list
        Bin column labels of the merged table, in their original order.
    meth_arrays : list of np.ndarray
        Unsmoothed mean profile of each plotted cluster, in cluster-id order, with
        values ordered by sorted bin position.

    Raises
    ------
    ValueError
        If no read of df has a locus cluster in df_dict.
    """
    # -------------------------
    # Step 1: readid as a column; the other index levels are dropped
    # -------------------------
    df_avg = df.reset_index(level='readid', drop=False)
    df_avg.index = range(len(df_avg))

    # -------------------------
    # Step 2: attach cluster id and color to every read
    # -------------------------
    df_merged = pd.merge(df_avg, df_dict, on='readid', how='inner')

    # -------------------------
    # Step 3: every non-metadata column is a bin position
    # -------------------------
    metadata_cols = ['readid', 'locus_cluster', 'locus_cluster_color']
    position_cols = [col for col in df_merged.columns if col not in metadata_cols]
    positions_sorted = sorted(int(col) for col in position_cols)

    # -------------------------
    # Step 4: cluster list and read proportions (%)
    # -------------------------
    unique_clusters = sorted(df_merged['locus_cluster'].dropna().astype(int).unique())
    n_clusters = len(unique_clusters)
    if n_clusters == 0:
        # plt.subplots(0, 1) would otherwise fail with an unhelpful message
        raise ValueError("No clustered reads: df and df_dict share no 'readid' with a locus_cluster.")

    proportion_df = (
        df_merged['locus_cluster']
        .value_counts(normalize=True)
        .mul(100)
        .reindex(unique_clusters, fill_value=0)
        .reset_index(name="percentage")
        .rename(columns={"index": "locus_cluster"})  # pandas < 2 names the reset column 'index'
    )

    # -------------------------
    # Step 5: one subplot per cluster, height ~ cluster proportion
    # -------------------------
    height_ratios = proportion_df['percentage'].values
    fig, axes = plt.subplots(n_clusters, 1, figsize=(10, 12), sharex=True,
                             gridspec_kw={'height_ratios': height_ratios})
    if n_clusters == 1:
        axes = [axes]

    fig.suptitle(
        f"{gene}: Average DNA Methylation Profiles per Cluster",
        fontsize=14,
        y=0.97  # closer to the top of the figure
    )

    # -------------------------
    # Step 6: plot each cluster
    # -------------------------
    meth_arrays = []

    for ax, cluster_id in zip(axes, unique_clusters):
        cluster_rows = df_merged[df_merged['locus_cluster'] == cluster_id]
        if cluster_rows.empty:
            continue

        # Mean methylation over the bin columns only (sorted integer labels)
        meth_sorted = cluster_rows[positions_sorted].mean(axis=0).values
        meth_arrays.append(meth_sorted)
        meth_smooth = gaussian_filter1d(meth_sorted, sigma=0.3)

        # Cluster color as background, white area under the profile
        ax.set_facecolor(cluster_rows['locus_cluster_color'].iloc[0])
        ax.fill_between(positions_sorted, meth_smooth, color='white')

        percentage = proportion_df.loc[
            proportion_df['locus_cluster'] == cluster_id, 'percentage'
        ].values[0]
        ax.set_ylabel(f"Cluster {cluster_id} | {percentage:.1f}%", fontsize=8)

        ax.set_ylim(0, 1.1)
        ax.set_xlim(min(positions_sorted), max(positions_sorted))
        ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='gray')

        # Reference lines: full methylation and Pol2 position
        ax.axhline(y=1, color='grey', linestyle='--', linewidth=0.5)
        ax.axvline(x=0, color="#C80028", linestyle="-", linewidth=2, label="Pol2 position")

    # -------------------------
    # Step 7: shared x-axis, global labels & legend
    # -------------------------
    # Subplots share x, so locator/formatter set on the bottom axis apply to all;
    # only the bottom axis shows tick labels.
    for ax in axes[:-1]:
        ax.tick_params(labelbottom=False)

    axes[-1].xaxis.set_major_locator(MultipleLocator(100))

    def rel_to_abs_label(x, pos):
        # Bin positions are shifted by center_coord to show absolute coordinates
        return f"{int(x + center_coord)}"

    axes[-1].xaxis.set_major_formatter(FuncFormatter(rel_to_abs_label))
    axes[-1].get_xaxis().get_offset_text().set_visible(False)

    fig.text(0.95, 0.5, 'Methylation level', va='center', ha='center', rotation=-90, fontsize=10)
    fig.text(
        0.05, 0.5,
        "Cluster number | Cluster proportion\n(ordered by cluster id)",
        fontsize=10, multialignment='center', rotation=90, va='center', ha='center'
    )
    fig.text(0.5, 0.04, "Genomic coordinate", va='center', ha='center', fontsize=10)

    # Legend handle styled like the Pol2 lines actually drawn on the subplots
    center_line = mlines.Line2D([], [], color="#C80028", linestyle='-', linewidth=2, label='Pol2 position')
    fig.legend(handles=[center_line], loc='upper right', bbox_to_anchor=(0.97, 0.955), fontsize=9, frameon=False)

    plt.subplots_adjust(hspace=0.05, top=0.92)
    plt.setp(axes[-1].get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
    axes[-1].tick_params(axis='x', labelsize=8)
    fig.subplots_adjust(bottom=0.18)  # extra bottom margin to avoid clipping

    plt.show()

    return position_cols, meth_arrays


In [None]:
# # (unused) FUNCTION TO PLOT AVERAGE METHYLATION PROFILES FOR THE BULK DATA
# def plot_avg_methylation_profile_bulk(
#     df: pd.DataFrame,
#     start: int | None = None,
#     end: int | None = None,
#     hex_colors=None,
#     smooth_sigma: float = 0.5,
#     order_by: str = "mean",  # 'mean' | 'cluster' | None
#     use_cov_for_proportion: bool = True
# ):
#     """
#     Plot average methylation profiles per cluster for bulk data.

#     Input df columns:
#     - 'cluster': cluster id (int or str)
#     - 'C_pos': genomic coordinate (int; already absolute or relative)
#     - 'meth': average methylation at C_pos for that cluster (0..1)
#     - 'cov': per-cluster coverage/size (assumed constant across rows or summable)

#     Args:
#     - start, end: optional x-range to display; if None, inferred from data.
#     - hex_colors: list of color hex strings; if None, use matplotlib tab20.
#     - smooth_sigma: Gaussian smoothing sigma for the curve.
#     - order_by: order subplots by 'mean' methylation, by 'cluster', or leave as-is (None).
#     - use_cov_for_proportion: if True, compute proportion from cov; else from row counts.

#     Returns:
#     - fig, axes
#     """

#     # Validate input
#     required = {'cluster', 'C_pos', 'meth', 'cov'}
#     missing = required - set(df.columns)
#     if missing:
#         raise ValueError(f"Dataframe missing columns: {missing}")

#     # Ensure numeric types
#     df = df.copy()
#     for c in ['C_pos', 'meth', 'cov']:
#         df[c] = pd.to_numeric(df[c], errors='coerce')
#     df = df.dropna(subset=['cluster', 'C_pos', 'meth'])

#     # Build cluster list and color map
#     unique_clusters = list(pd.unique(df['cluster']))
#     n_clusters = len(unique_clusters)

#     if hex_colors is None:
#         import matplotlib.cm as cm
#         cmap = cm.get_cmap('tab20', n_clusters)
#         hex_colors = [cm.colors.to_hex(cmap(i)) for i in range(n_clusters)]
#     color_map = {cid: hex_colors[i % len(hex_colors)] for i, cid in enumerate(unique_clusters)}

#     # Compute cluster stats
#     g = df.groupby('cluster')
#     cluster_mean = g['meth'].mean()

#     # Proportion from coverage
#     if use_cov_for_proportion:
#         cov_first = g['cov'].first()
#         cov_var = g['cov'].var()
#         if np.nanmax(cov_var.fillna(0)) > 0:
#             cov = g['cov'].sum()
#         else:
#             cov = cov_first
#         proportions = (cov / cov.sum()) * 100.0
#     else:
#         counts = g.size()
#         proportions = (counts / counts.sum()) * 100.0

#     # Order clusters
#     if order_by == "mean":
#         ordered_clusters = list(cluster_mean.sort_values(ascending=False).index)
#     elif order_by == "cluster":
#         try:
#             ordered_clusters = sorted(unique_clusters)
#         except Exception:
#             ordered_clusters = unique_clusters
#     else:
#         ordered_clusters = unique_clusters

#     # Height ratios from proportions (avoid zeros)
#     height_ratios = [max(1e-3, float(proportions.get(cid, 0.0))) for cid in ordered_clusters]

#     # Figure and axes
#     fig, axes = plt.subplots(
#         n_clusters,
#         1,
#         figsize=(16, 18),
#         sharex=True,
#         gridspec_kw={'height_ratios': height_ratios}
#     )
#     if n_clusters == 1:
#         axes = [axes]

#     fig.suptitle("Average DNA Methylation Profiles per Cluster", fontsize=14, y=0.97)

#     # Plot each cluster
#     for i, cid in enumerate(ordered_clusters):
#         ax = axes[i]
#         sub = df[df['cluster'] == cid].copy()
#         if sub.empty:
#             ax.set_visible(False)
#             continue

#         # Aggregate duplicates at the same position
#         sub = sub.groupby('C_pos', as_index=False).agg({'meth': 'mean', 'cov': 'first'})
#         sub = sub.sort_values('C_pos')

#         # Extract arrays
#         x = sub['C_pos'].to_numpy()
#         y = sub['meth'].to_numpy()

#         # Optional smoothing
#         if smooth_sigma and smooth_sigma > 0:
#             y_smooth = gaussian_filter1d(y, sigma=float(smooth_sigma))
#         else:
#             y_smooth = y

#         color = color_map[cid]

#         # Background
#         ax.set_facecolor(color)
#         # Fill and line
#         ax.fill_between(x, y_smooth, color='white')
#         # ax.plot(x, y_smooth, color='white', linewidth=1)

#         # Labels and axes
#         pct = float(proportions.get(cid, 0.0))
#         ax.set_ylabel(f"Cluster {cid} | {pct:.1f}%", fontsize=8)
#         ax.set_ylim(0, 1.05)

#         # Determine x-range
#         x_min = np.nanmin(x) if start is None else start
#         x_max = np.nanmax(x) if end is None else end
#         ax.set_xlim(x_min, x_max)

#         # Grid
#         ax.grid(True, axis='x', linestyle='--', linewidth=0.5, color='gray')

#     # Shared x formatting
#     for ax in axes[:-1]:
#         ax.tick_params(labelbottom=False)

#     axes[-1].xaxis.set_major_locator(MultipleLocator(100))
#     axes[-1].xaxis.set_major_formatter(FuncFormatter(lambda v, pos: f"{int(v)}"))
#     axes[-1].get_xaxis().get_offset_text().set_visible(False)

#     # Global labels
#     fig.text(0.95, 0.5, 'Methylation level', va='center', ha='center', rotation=-90, fontsize=10)
#     fig.text(
#         0.05,
#         0.5,
#         "Cluster number | Cluster proportion (ordered by cluster mean methylation)",
#         fontsize=10,
#         rotation=90,
#         va='center',
#         ha='center'
#     )
#     fig.text(0.5, 0.04, "Genomic coordinate", va='center', ha='center', fontsize=10)

#     plt.subplots_adjust(hspace=0.05, top=0.92)
#     plt.setp(axes[-1].get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
#     axes[-1].tick_params(axis='x', labelsize=8)
#     fig.subplots_adjust(bottom=0.18)

#     plt.show()
#     return fig, axes


In [None]:
def compute_bulk_centroids(
    df: pd.DataFrame,
    use_cov_for_proportion: bool = True
):
    """
    Compute per-cluster centroids across the entire bulk table.

    Input df columns: ['cluster','C_pos','meth','cov']. Rows with a missing cluster,
    position or methylation value (after numeric coercion) are ignored.

    Parameters
    ----------
    use_cov_for_proportion : bool
        If True, cluster proportions come from 'cov'. Rows are summed when cov varies
        within any cluster; otherwise the first value per cluster is used. If False,
        proportions come from row counts.

    Returns:
      P_df: clusters × bins matrix (mean methylation per bin)
      cov_df: clusters × bins matrix (first 'cov' value per cluster/bin)
      meta_df: per-cluster summary with ['cluster_id', 'cov_total' or 'n_rows', 'proportion'].
               Integral cluster labels (including floats such as 1.0) are stored as int,
               other labels as str.
      positions: sorted integer bin coordinates (C_pos)

    Raises
    ------
    ValueError
        If required columns are missing or no valid rows remain.
    """
    required = {'cluster','C_pos','meth','cov'}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Dataframe missing columns: {missing}")

    d = df.copy()
    for col in ('C_pos', 'meth', 'cov'):
        d[col] = pd.to_numeric(d[col], errors='coerce')
    d = d.dropna(subset=['cluster','C_pos','meth'])
    if d.empty:
        raise ValueError("No valid rows left after dropping missing cluster/C_pos/meth values.")

    # Aggregate duplicates at the same position within cluster
    agg = (
        d.groupby(['cluster','C_pos'], as_index=False)
         .agg(meth_mean=('meth','mean'),
              cov_agg=('cov','first'))  # use 'sum' if cov varies per row
    )

    # Pivot to wide: cluster × position
    P_df = agg.pivot(index='cluster', columns='C_pos', values='meth_mean')
    cov_df = agg.pivot(index='cluster', columns='C_pos', values='cov_agg')

    # Sort columns and convert to ints
    positions = np.array(sorted(P_df.columns.astype(int)), dtype=int)
    P_df = P_df.reindex(columns=positions)
    cov_df = cov_df.reindex(columns=positions)

    # Proportions (used as subplot height ratios downstream)
    g = d.groupby('cluster')
    if use_cov_for_proportion:
        cov_var = g['cov'].var()
        # If cov varies within any cluster, rows are additive; otherwise cov is a per-cluster constant
        cov_tot = g['cov'].sum() if np.nanmax(cov_var.fillna(0)) > 0 else g['cov'].first()
        total_cov = float(cov_tot.sum())
        proportions = (cov_tot / (total_cov if total_cov > 0 else 1.0)) * 100.0
        n_metric = cov_tot.astype(float)
        n_label = 'cov_total'
    else:
        counts = g.size()
        proportions = (counts / counts.sum()) * 100.0
        n_metric = counts.astype(float)
        n_label = 'n_rows'

    def _cluster_id(cid):
        # Integral labels (int, or float such as 1.0) -> int so downstream int() works;
        # anything else -> str, as before
        if isinstance(cid, (int, np.integer)):
            return int(cid)
        if isinstance(cid, (float, np.floating)) and float(cid).is_integer():
            return int(cid)
        return str(cid)

    meta_rows = [
        {
            'cluster_id': _cluster_id(cid),
            n_label: float(n_metric.get(cid, np.nan)),
            'proportion': float(proportions.get(cid, 0.0)),
        }
        for cid in P_df.index
    ]
    meta_df = pd.DataFrame(meta_rows).sort_values('cluster_id').reset_index(drop=True)

    return P_df, cov_df, meta_df, positions


In [None]:
def compute_gene_centroids(
    df_reads: pd.DataFrame,
    df_map: pd.DataFrame,  # ['readid', 'locus_cluster', 'locus_cluster_color']
    gene_tss: str,
    read_id_col: str = 'readid',
    cluster_col: str = 'locus_cluster',
    color_col: str = 'locus_cluster_color',
    extra_meta_cols: list | None = None  # e.g., ['gene_tss', 'group']
):
    """
    Compute per-cluster average methylation per bin (and coverage) for one gene.

    Returns
    -------
    profiles_df : pd.DataFrame
        Mean methylation per cluster × bin.
    coverage_df : pd.DataFrame
        Number of reads with non-NaN methylation per cluster × bin.
    meta_df : pd.DataFrame
        Metadata for each cluster: [gene_tss, cluster_id, n_reads, proportion, color].
    positions : np.ndarray
        Sorted numeric bin positions (integers).
    """
    # --- Ensure read IDs are present as a column ---
    if read_id_col not in df_reads.columns:
        df_w = df_reads.reset_index().rename(columns={'index': read_id_col})
    else:
        df_w = df_reads.copy()

    # --- Merge cluster/color mapping ---
    dfg = pd.merge(
        df_w,
        df_map[[read_id_col, cluster_col, color_col]],
        on=read_id_col,
        how='inner'
    )

    if dfg.empty:
        raise ValueError("No overlapping reads between df_reads and df_map.")

    # --- Identify metadata columns ---
    meta_cols = {read_id_col, cluster_col, color_col}
    if extra_meta_cols:
        meta_cols |= set(extra_meta_cols)

    # --- Select candidate bin columns ---
    candidate_cols = [c for c in dfg.columns if c not in meta_cols]

    # --- Keep only columns with numeric names AND numeric data ---
    def int_like(name):
        return (
            isinstance(name, (int, np.integer))
            or (isinstance(name, str) and name.strip().lstrip('-').isdigit())
        )

    bin_cols_num = []
    positions = []
    for c in candidate_cols:
        if int_like(c) and pd.api.types.is_numeric_dtype(dfg[c]):
            bin_cols_num.append(c)
            positions.append(int(c) if not isinstance(c, (int, np.integer)) else int(c))

    if not bin_cols_num:
        raise ValueError(
            "No numeric bin columns found. Check df_reads columns and extra_meta_cols."
        )

    positions = np.array(sorted(positions), dtype=int)

    # --- Build clean numeric-column version ---
    col_map = {
        c: (int(c) if not isinstance(c, (int, np.integer)) else int(c))
        for c in bin_cols_num
    }
    dfg_bins = dfg[[cluster_col, color_col] + bin_cols_num].rename(columns=col_map)
    dfg_bins = dfg_bins[[cluster_col, color_col] + positions.tolist()]

    # --- Group by cluster and compute stats ---
    profiles = {}
    coverage = {}
    colors = {}
    counts_per_cluster = {}

    for cid, sub in dfg_bins.groupby(cluster_col):
        vals = sub[positions]
        profiles[cid] = vals.mean(axis=0, skipna=True)
        coverage[cid] = vals.notna().sum(axis=0)
        colors[cid] = (
            sub[color_col].iloc[0]
            if color_col in sub.columns and not sub[color_col].isna().all()
            else None
        )
        counts_per_cluster[cid] = int(len(sub))

    profiles_df = pd.DataFrame(profiles).T.reindex(columns=positions)
    coverage_df = pd.DataFrame(coverage).T.reindex(columns=positions)

    profiles_df.index.name = 'cluster'
    profiles_df.columns.name = 'C_pos'

    # --- Build meta table ---
    total_reads = sum(counts_per_cluster.values())
    meta_rows = []
    for cid in profiles_df.index:
        n_reads = counts_per_cluster.get(cid, 0)
        pct = (n_reads / total_reads * 100.0) if total_reads > 0 else 0.0
        meta_rows.append({
            'gene_tss': gene_tss,
            'cluster_id': int(cid),
            'n_reads': int(n_reads),
            'proportion': float(pct),
            'cluster_color': colors.get(cid)
        })

    meta_df = (
        pd.DataFrame(meta_rows)
        .sort_values('cluster_id')
        .reset_index(drop=True)
    )


    return profiles_df, coverage_df, meta_df, positions


In [None]:
def _shading_runs(x, mask):
    """Group consecutive entries of `x` where `mask` is True into inclusive (first, last) runs."""
    runs = []
    run_start = prev = None
    for p, flagged in zip(x, mask):
        if flagged:
            if run_start is None:
                run_start = p
            prev = p
        elif run_start is not None:
            runs.append((run_start, prev))
            run_start = None
    if run_start is not None:
        runs.append((run_start, prev))
    return runs


def plot_centroids_with_shading(
    P_df: pd.DataFrame,                 # index = cluster_id, columns = positions (int), values = mean methylation
    positions: np.ndarray,              # sorted integer positions
    meta_df: pd.DataFrame,              # must include ['cluster_id', 'proportion']; optional ['n_reads', 'cluster_color']
    coverage_df: pd.DataFrame | None = None,  # same shape as P_df; per-bin contributing read counts
    hex_colors=None,
    smooth_sigma: float = 0.5,
    start: int | None = None,
    end: int | None = None,
    title: str = "Average DNA Methylation Profiles per Cluster",
    show_pol2_line: bool = True,
    shade_missing: bool = True,          # enable/disable shading
    missingness_threshold: float = 0.7,  # shade sites where missingness >= threshold
    bin_width: int | None = None         # if None, inferred from median bin spacing
):
    """
    Plot one stacked subplot per cluster centroid, optionally shading sparse bins.

    Subplot heights follow meta_df['proportion'] (equal shares if absent). Each
    profile is Gaussian-smoothed (`smooth_sigma` > 0) and drawn as a white area on
    the cluster color. Colors come from meta_df['cluster_color'] when available,
    otherwise from `hex_colors`, otherwise from the 'tab20' colormap. Clusters
    without a usable color are drawn on grey (#cccccc).

    When `shade_missing` and `coverage_df` are given, bins where
    1 - coverage / cluster_size >= `missingness_threshold` are shaded grey.
    cluster_size is meta_df['n_reads'] if present, otherwise the maximum coverage
    of the cluster.

    Cluster ids are expected to be convertible with int().

    Returns
    -------
    fig, axes

    Raises
    ------
    ValueError
        If P_df has no rows.
    """
    import matplotlib.colors as mcolors

    # --- Clusters and proportions ---
    clusters = list(P_df.index)
    n_clusters = len(clusters)
    if n_clusters == 0:
        raise ValueError("P_df has no clusters to plot.")

    if 'proportion' in meta_df.columns:
        prop_map = {int(row['cluster_id']): float(row['proportion'])
                    for _, row in meta_df.iterrows()}
    else:
        prop_map = {int(cid): 100.0 / max(n_clusters, 1) for cid in clusters}

    # --- Colors (keyed by int cluster id, like prop_map) ---
    if 'cluster_color' in meta_df.columns and meta_df['cluster_color'].notna().any():
        color_map = {int(row['cluster_id']): row['cluster_color']
                     for _, row in meta_df.iterrows()}
    else:
        if hex_colors is None:
            # pyplot.get_cmap still exists on matplotlib >= 3.9, where cm.get_cmap was removed
            cmap = plt.get_cmap('tab20', n_clusters)
            hex_colors = [mcolors.to_hex(cmap(i)) for i in range(n_clusters)]
        color_map = {int(cid): hex_colors[i % len(hex_colors)]
                     for i, cid in enumerate(clusters)}

    # --- Height ratios from proportions (tiny floor: zero ratios are not allowed) ---
    height_ratios = [max(1e-3, float(prop_map.get(int(cid), 0.0)))
                     for cid in clusters]

    # --- Infer bin width ---
    if bin_width is None and len(positions) > 1:
        bin_width = int(np.median(np.diff(positions)))
    if bin_width is None:
        bin_width = 1

    # --- Estimate per-cluster size for missingness ---
    n_reads_map = {}
    if 'n_reads' in meta_df.columns:
        n_reads_map = {int(row['cluster_id']): int(row['n_reads'])
                       for _, row in meta_df.iterrows()}
    elif coverage_df is not None:
        for cid in clusters:
            try:
                n_reads_map[int(cid)] = int(np.nanmax(
                    coverage_df.loc[cid].to_numpy(float)))
            except (KeyError, ValueError, TypeError):
                # cluster absent from coverage_df, or no finite coverage: no shading
                n_reads_map[int(cid)] = 0

    # --- Create figure ---
    fig, axes = plt.subplots(
        n_clusters, 1, figsize=(16, 18), sharex=True,
        gridspec_kw={'height_ratios': height_ratios}
    )
    if n_clusters == 1:
        axes = [axes]
    fig.suptitle(title, fontsize=14, y=0.97)

    # --- Plot each cluster ---
    for ax, cid in zip(axes, clusters):
        y = P_df.loc[cid].to_numpy(float)
        x = positions
        y_smooth = (gaussian_filter1d(y, sigma=float(smooth_sigma))
                    if smooth_sigma and smooth_sigma > 0 else y)

        # --- Shading of high missingness sites ---
        if shade_missing and coverage_df is not None:
            cluster_size = int(n_reads_map.get(int(cid), 0))
            if cluster_size > 0:
                cov = coverage_df.loc[cid].to_numpy(float)
                mask = (1.0 - cov / cluster_size) >= float(missingness_threshold)
                for a, b in _shading_runs(x, mask):
                    ax.axvspan(a - bin_width / 2, b + bin_width / 2,
                               color="#969696", alpha=0.7, zorder=0.9)

        # --- Plot profile ---
        face = color_map.get(int(cid))
        if face is None or (isinstance(face, float) and np.isnan(face)):
            face = '#cccccc'  # missing color would make set_facecolor fail
        ax.set_facecolor(face)
        ax.fill_between(x, y_smooth, color='white')

        # --- Axes, labels ---
        pct = float(prop_map.get(int(cid), 0.0))
        ax.set_ylabel(f"Cluster {cid} | {pct:.1f}%", fontsize=8)
        ax.set_ylim(0, 1.05)

        x_min = np.nanmin(x) if start is None else start
        x_max = np.nanmax(x) if end is None else end
        ax.set_xlim(x_min, x_max)

        ax.grid(True, axis='x', linestyle='--',
                linewidth=0.5, color='gray')
        if show_pol2_line:
            ax.axvline(x=0, color="#C80028", linestyle="-", linewidth=2)

    # --- Shared x-axis formatting ---
    for ax in axes[:-1]:
        ax.tick_params(labelbottom=False)
    axes[-1].xaxis.set_major_locator(MultipleLocator(100))
    axes[-1].xaxis.set_major_formatter(FuncFormatter(lambda v, pos: f"{int(v)}"))
    axes[-1].get_xaxis().get_offset_text().set_visible(False)

    # --- Global labels ---
    fig.text(0.95, 0.5, 'Methylation level',
             va='center', ha='center', rotation=-90, fontsize=10)
    fig.text(0.05, 0.5, "Cluster number | Cluster proportion",
             fontsize=10, rotation=90, va='center', ha='center')
    fig.text(0.5, 0.04, "Genomic coordinate",
             va='center', ha='center', fontsize=10)

    plt.subplots_adjust(hspace=0.05, top=0.92)
    plt.setp(axes[-1].get_xticklabels(), rotation=45,
             ha='right', rotation_mode='anchor')
    axes[-1].tick_params(axis='x', labelsize=8)
    fig.subplots_adjust(bottom=0.18)

    plt.show()
    return fig, axes


In [None]:
# -------------------------------------------------------------------
# 1. Masked cosine similarity with NaN handling and column alignment
# -------------------------------------------------------------------
def masked_cosine_cross(
    G_df: pd.DataFrame,
    P_df: pd.DataFrame,
    min_overlap: int = 10,
    center_rows: bool = True,
    weights=None  # optional: pd.Series indexed by columns (positions), or array-like
) -> np.ndarray:
    """
    Compute m × p cosine similarities between rows of G and rows of P.

    NaNs are masked per pair, columns are aligned by intersection, and each pair must
    share at least `min_overlap` non-NaN bins.

    Parameters
    ----------
    G_df, P_df : pd.DataFrame
        DataFrames with columns = positions (bins): ints or numeric strings.
        Duplicate positions are averaged. Columns may differ; they are aligned by
        intersection.
    min_overlap : int
        Pairs sharing fewer finite bins get similarity 0.
    center_rows : bool
        Mean-center each pair's overlapping values first (cosine then equals Pearson r).
    weights : pd.Series or array-like, optional
        Per-bin reliability weights; columns are scaled by sqrt(max(weight, 0)).
        A Series is matched by position label (numeric-string labels allowed). An
        array must have one value per common bin, or one per column of P_df.
        Bins without a weight become NaN and are thus masked out.

    Returns
    -------
    S : np.ndarray
        (m × p) cosine similarity matrix.

    Raises
    ------
    ValueError
        On non-numeric bin labels, no shared bins, or a weights array of the wrong length.
    """

    def _int_labels(labels) -> pd.Index:
        """Coerce bin labels (ints or numeric strings) to an int Index; raise otherwise."""
        as_str = pd.Series(labels, dtype=str).str.strip()
        as_num = pd.to_numeric(as_str, errors='coerce')
        if as_num.isna().any():
            bad = as_str[as_num.isna()].tolist()
            raise ValueError(f"Non-numeric bin columns found: {bad}")
        return pd.Index(as_num.astype(int))

    # ---- 1) Coerce column labels to int and sort ----
    def _coerce_sort(df):
        out = df.copy()
        out.columns = _int_labels(out.columns)
        if out.columns.duplicated().any():
            # average duplicate bins (transpose: groupby(axis=1) is deprecated/removed)
            out = out.T.groupby(level=0).mean().T
        return out.sort_index(axis=1)

    Gc = _coerce_sort(G_df)
    Pc = _coerce_sort(P_df)

    # ---- 2) Align to intersection of positions ----
    common = np.intersect1d(Gc.columns.values, Pc.columns.values)
    if common.size == 0:
        raise ValueError("No overlapping bins between G and P.")

    Gc = Gc.reindex(columns=common)
    Pc = Pc.reindex(columns=common)

    # ---- 3) Apply weights (scale columns by sqrt(weights)) ----
    if weights is not None:
        if isinstance(weights, pd.Series):
            w_series = pd.Series(weights.to_numpy(float), index=_int_labels(weights.index))
        else:
            w = np.asarray(weights, float)
            if w.shape[0] == common.size:
                w_series = pd.Series(w, index=common)
            elif w.shape[0] == len(P_df.columns):
                # one weight per original P_df column: key them by the coerced int positions
                w_series = pd.Series(w, index=_int_labels(P_df.columns))
            else:
                raise ValueError(f"weights length {w.shape[0]} must equal n_bins {common.size}.")
        # average weights of duplicate positions, like the data columns
        w_series = w_series.groupby(level=0).mean()
        sw = np.sqrt(np.maximum(w_series.reindex(common).to_numpy(float), 0.0))
        Gc = Gc.mul(sw, axis=1)
        Pc = Pc.mul(sw, axis=1)

    # ---- 4) Compute masked cosine over aligned matrices ----
    G = Gc.to_numpy(float)
    P = Pc.to_numpy(float)
    G_mask = np.isfinite(G)
    P_mask = np.isfinite(P)
    min_ov = int(min_overlap)
    S = np.zeros((G.shape[0], P.shape[0]), float)

    for i, (gi, mi) in enumerate(zip(G, G_mask)):
        for j, (pj, mj) in enumerate(zip(P, P_mask)):
            ov = mi & mj
            if int(ov.sum()) < min_ov:
                continue  # too few shared bins: similarity stays 0
            vi, vj = gi[ov], pj[ov]  # boolean indexing copies; G and P are not modified
            if center_rows:
                vi = vi - vi.mean()
                vj = vj - vj.mean()
            denom = np.linalg.norm(vi) * np.linalg.norm(vj)
            if denom > 0:
                S[i, j] = float(np.dot(vi, vj) / denom)

    return S


# -------------------------------------------------------------------
# 2. Utility functions for Pearson-based similarity
# -------------------------------------------------------------------
def _zscore_rows(M, eps: float = 1e-6) -> np.ndarray:
    M = np.asarray(M, float)
    mu = np.nanmean(M, axis=1, keepdims=True)
    sd = np.nanstd(M, axis=1, keepdims=True)
    sd = np.where(sd < eps, eps, sd)
    return (M - mu) / sd


def _pairwise_pearson(G: np.ndarray, P: np.ndarray) -> np.ndarray:
    G = np.asarray(G, float)
    P = np.asarray(P, float)
    m, d = G.shape
    p = P.shape[0]
    S = np.zeros((m, p), float)

    for i in range(m):
        gi = G[i]
        for j in range(p):
            pj = P[j]
            mask = np.isfinite(gi) & np.isfinite(pj)
            if mask.sum() < 2:
                S[i, j] = 0.0
                continue
            gi_m, pj_m = gi[mask], pj[mask]
            gi_sd, pj_sd = gi_m.std(), pj_m.std()
            gi_m = (gi_m - gi_m.mean()) / (gi_sd if gi_sd > 0 else 1.0)
            pj_m = (pj_m - pj_m.mean()) / (pj_sd if pj_sd > 0 else 1.0)
            S[i, j] = float(np.dot(gi_m, pj_m) / mask.sum())
    return S


# -------------------------------------------------------------------
# 3. Compute similarity and distance matrices
# -------------------------------------------------------------------
def compute_similarity_matrix(
    G,
    P,
    method: str = 'pearson',
    use_abs: bool = False,
    center_rows: bool = True,
    weights=None,
    min_overlap: int = 10
):
    """
    Compute similarity (and corresponding distance) matrix between rows of G and P.

    Parameters
    ----------
    G, P : array-like or pd.DataFrame
        Matrices or DataFrames representing profiles (rows=clusters, cols=bins).
        For 'pearson', when both are DataFrames their column labels are parsed as
        integer bin positions (duplicated positions are averaged) and both are
        aligned on the shared positions. Otherwise they are used as-is and must
        have the same number of columns. 'cosine' requires DataFrames.
    method : {'pearson', 'cosine'}, default='pearson'
        'pearson': correlation over the bins finite in both rows.
        'cosine' : masked cosine computed by ``masked_cosine_cross``.
    use_abs : bool, default=False
        Take absolute value of similarity if True.
    center_rows : bool, default=True
        Forwarded to the cosine computation. Pearson always centers rows.
    weights : array-like, optional
        Positional weights, forwarded to the cosine computation only.
    min_overlap : int, default=10
        Minimum number of bins finite in both rows; pairs with fewer shared bins
        get similarity 0. Applied by both methods.

    Returns
    -------
    S : np.ndarray
        Similarity matrix (rows=G, cols=P).
    D : np.ndarray
        Distance matrix (1 - S, clipped at 0).
    prot_ids : list
        Prototype identifiers (row labels of P if DataFrame, else range).

    Raises
    ------
    ValueError
        Unknown ``method``; non-numeric or non-overlapping bin columns (pearson
        with DataFrames); non-DataFrame inputs for 'cosine'.
    """

    if method == 'pearson':
        # ---- Coerce and align DataFrames to common bins ----
        if isinstance(G, pd.DataFrame) and isinstance(P, pd.DataFrame):

            def _coerce_sort(df):
                """Parse bin columns as int positions, average duplicates, sort by position."""
                out = df.copy()
                cols = pd.Series(out.columns, dtype=str).str.strip()
                nums = pd.to_numeric(cols, errors='coerce')
                if nums.isna().any():
                    bad = cols[nums.isna()].tolist()
                    raise ValueError(f"Non-numeric bin columns: {bad}")
                out.columns = nums.astype(int)
                if out.columns.duplicated().any():
                    # groupby(axis=1) is deprecated since pandas 2.1: group the transpose instead
                    out = out.T.groupby(level=0).mean().T
                return out.sort_index(axis=1)

            Gc = _coerce_sort(G)
            Pc = _coerce_sort(P)
            common = np.intersect1d(Gc.columns.values, Pc.columns.values)
            if common.size == 0:
                raise ValueError("No overlapping bins between G and P.")

            G_arr = Gc.reindex(columns=common).to_numpy(float)
            P_arr = Pc.reindex(columns=common).to_numpy(float)

        else:
            G_arr = np.asarray(G, float)
            P_arr = np.asarray(P, float)

        # ---- Compute Pearson correlation ----
        Zg = _zscore_rows(G_arr)
        Zp = _zscore_rows(P_arr)
        S = _pairwise_pearson(Zg, Zp)  # masks NaNs per pair
        S = np.nan_to_num(S, nan=0.0)

        # ---- Enforce min_overlap (the cosine path does the same inside masked_cosine_cross) ----
        # n_shared[i, j] = number of bins finite in both Zg[i] and Zp[j]
        n_shared = np.isfinite(Zg).astype(float) @ np.isfinite(Zp).astype(float).T
        S[n_shared < int(min_overlap)] = 0.0

    elif method == 'cosine':
        if not isinstance(G, pd.DataFrame) or not isinstance(P, pd.DataFrame):
            raise ValueError("For masked cosine, pass G and P as DataFrames (columns = positions).")

        S = masked_cosine_cross(
            G_df=G,
            P_df=P,
            min_overlap=min_overlap,
            center_rows=center_rows,
            weights=weights,
        )

    else:
        raise ValueError("method must be 'pearson' or 'cosine'.")

    if use_abs:
        S = np.abs(S)

    D = 1.0 - S
    D[D < 0] = 0.0

    # ---- Preserve prototype IDs ----
    prot_ids = list(P.index) if isinstance(P, pd.DataFrame) else list(range(S.shape[1]))

    return S, D, prot_ids


# -------------------------------------------------------------------
# 4. Assign gene-level clusters to consensus clusters
# -------------------------------------------------------------------
def assign_gene_clusters_to_consensus(
    G,
    meta: pd.DataFrame,
    P,
    method: str = 'pearson',
    use_abs: bool = False,
    center_rows: bool = True,
    weights=None,
    min_overlap: int = 10,
    min_similarity: float | None = None,
    allow_unassigned: bool = True,
    capacity_per_type: int = 1,
):
    """
    Assign gene clusters (rows in G) to consensus clusters (rows in P)
    using pairwise similarity + Hungarian assignment, solved separately per gene.

    Parameters
    ----------
    G, P : array-like or pd.DataFrame
        Gene-cluster profiles and bulk prototype profiles (rows = clusters,
        columns = bins). Both are forwarded to ``compute_similarity_matrix``.
    meta : pd.DataFrame
        One row per row of ``G``, matched **by position** (index labels are
        ignored). Must contain 'gene_tss' and 'cluster_id'. 'size' is optional
        and is filled with NaN when missing.
    method, use_abs, center_rows, weights, min_overlap
        Forwarded to ``compute_similarity_matrix``.
    min_similarity : float, optional
        With ``allow_unassigned=True``, the cost of leaving a cluster unassigned
        is ``1 - min_similarity``. Matches with a higher distance are therefore
        (globally) avoided. When None, the penalty is the gene's maximum
        distance + 0.05.
    allow_unassigned : bool, default=True
        Add one "unassigned" dummy column per cluster to the cost matrix.
    capacity_per_type : int, default=1
        Maximum number of clusters of one gene that may map to the same prototype.

    Returns
    -------
    assign_df : pd.DataFrame
        One row per matched cluster, with the following columns:
        gene_tss, cluster_id, size,
        consensus_id (None = unassigned),
        similarity (1 - matched cost),
        second_best_similarity,
        margin (second_best_cost - best_cost, i.e. similarity - second_best_similarity).
        The second best only considers columns that lead to a *different* outcome
        (another prototype, or "unassigned"). Duplicated capacity/dummy columns
        therefore cannot force margin = 0.
    S, D_base : np.ndarray
        Full similarity and distance matrices (rows = G, cols = P).

    Notes
    -----
    With ``allow_unassigned=False`` and more clusters than prototype slots, the
    clusters left out by ``linear_sum_assignment`` do not appear in ``assign_df``.
    """

    S, D_base, prot_ids = compute_similarity_matrix(
        G=G,
        P=P,
        method=method,
        use_abs=use_abs,
        center_rows=center_rows,
        weights=weights,
        min_overlap=min_overlap,
    )

    # Positional 0..n-1 index: groupby labels below are then valid row positions
    # for both D_base and meta_df.iloc (and meta itself is not modified).
    meta_df = meta.reset_index(drop=True)
    if 'size' not in meta_df.columns:
        meta_df['size'] = np.nan

    capacity = int(capacity_per_type)
    assignments = []
    for gene_id, row_labels in meta_df.groupby('gene_tss').groups.items():
        positions = np.asarray(row_labels, dtype=int)
        D = D_base[positions, :]
        n_rows, n_prot = D.shape

        if capacity > 1:
            # Repeat along the same axis for both arrays so column k of D_aug is prototype P_rep[k]
            P_rep = np.repeat(np.arange(n_prot), repeats=capacity)
            D_aug = np.repeat(D, repeats=capacity, axis=1)
        else:
            P_rep = np.arange(n_prot)
            D_aug = D

        if allow_unassigned:
            if min_similarity is not None:
                penalty = 1.0 - float(min_similarity)
            elif D_aug.size:
                penalty = float(np.nanmax(D_aug) + 0.05)
            else:
                penalty = 1.0  # no prototypes: every cluster ends up unassigned
            dummy = np.full((n_rows, n_rows), penalty, dtype=float)
            D_final = np.hstack([D_aug, dummy])
            col_map = np.concatenate([P_rep, -np.ones(n_rows, dtype=int)])
        else:
            D_final = D_aug
            col_map = P_rep

        r_idx, c_idx = linear_sum_assignment(D_final)

        for r, c in zip(r_idx, c_idx):
            pos = int(positions[r])          # row of this cluster in meta / G
            cons_col = int(col_map[c])       # prototype column, or -1 = unassigned
            best_cost = float(D_final[r, c])
            sim = 1.0 - best_cost

            # Alternatives = columns mapping to a different prototype / outcome
            alternatives = col_map != cons_col
            if alternatives.any():
                second_best_cost = float(np.min(D_final[r, alternatives]))
            else:
                second_best_cost = np.nan
            if np.isfinite(second_best_cost):
                second_best_sim = 1.0 - second_best_cost
                margin = second_best_cost - best_cost
            else:
                second_best_sim = np.nan
                margin = np.nan

            cons_id = None if cons_col < 0 else prot_ids[cons_col]
            meta_row = meta_df.iloc[pos]

            assignments.append({
                'gene_tss': gene_id,
                'cluster_id': meta_row['cluster_id'],
                'size': meta_row['size'],
                'consensus_id': cons_id,
                'similarity': sim,
                'second_best_similarity': second_best_sim,
                'margin': margin,
            })

    assign_df = pd.DataFrame(assignments)
    return assign_df, S, D_base


In [None]:
def plot_gene_mapping_heatmap(
    assign_df: pd.DataFrame,
    S: np.ndarray,                  # full similarity matrix returned by assign_gene_clusters_to_consensus
    meta: pd.DataFrame,             # meta passed to assign (aligned to rows of G)
    P,                              # bulk prototypes (DataFrame for cosine, array for pearson)
    gene_id: str,
    sort_rows: bool = True,         # sort by assigned prototype then by max similarity
    cmap: str = "viridis",
    vmin: float = 0.0,
    vmax: float = 1.0,
):
    """
    Plot a heatmap of similarities (rows = gene clusters, cols = bulk prototypes)
    for a single gene and overlay markers for the Hungarian assignment.

    Notes
    -----
    - S must be the full similarity matrix returned by `assign_gene_clusters_to_consensus`.
      Rows of S align with rows of `meta` **by position** (row k of meta <-> S[k]).
      The index labels of `meta` are ignored, so a filtered/non-RangeIndex meta works.
    - P can be a DataFrame (then prot IDs are P.index) or a numpy array (then prot IDs are 0..P.shape[0]-1).
    - Pearson similarities can be negative; with the default vmin=0 they saturate at the low end of `cmap`.

    Returns
    -------
    (fig, ax) of the created figure.

    Raises
    ------
    ValueError
        If `meta` has no row for `gene_id`.
    """
    # Prototype IDs (columns of S)
    if isinstance(P, pd.DataFrame):
        prot_ids = list(P.index)
    else:
        prot_ids = list(range(P.shape[0]))

    # Row positions for the selected gene (rows of S), independent of meta's index labels
    meta_pos = meta.reset_index(drop=True)
    idx_rows = meta_pos.index[meta_pos["gene_tss"] == gene_id].tolist()
    if not idx_rows:
        raise ValueError(f"No rows in meta for gene_id={gene_id}.")

    # Subset similarity and labels (rows in the same order as S)
    S_sub = S[idx_rows, :]
    row_labels = meta_pos.loc[idx_rows, "cluster_id"].astype(int).tolist()
    col_labels = prot_ids

    ass_sub = assign_df[assign_df["gene_tss"] == gene_id]

    # Row ordering / sorting
    if sort_rows:
        # map cluster_id -> assigned consensus_id
        ass_map = {int(r["cluster_id"]): r["consensus_id"] for _, r in ass_sub.iterrows()}
        sort_keys = []
        for i, cid in enumerate(row_labels):
            assigned = ass_map.get(int(cid), None)
            j_ass = col_labels.index(assigned) if (assigned is not None and assigned in col_labels) else -1
            # use -max_similarity so that higher similarity sorts earlier within same assigned group
            sort_keys.append((j_ass, -float(np.nanmax(S_sub[i])) if S_sub.size else 0.0))
        order = sorted(range(len(sort_keys)), key=lambda ii: sort_keys[ii])
        S_plot = S_sub[order, :]
        row_labels_plot = [row_labels[i] for i in order]
    else:
        S_plot = S_sub
        row_labels_plot = row_labels

    # Create heatmap
    fig, ax = plt.subplots(
        figsize=(1.0 + 0.4 * len(col_labels), 0.6 + 0.4 * len(row_labels_plot))
    )
    sns.heatmap(
        S_plot,
        xticklabels=col_labels,
        yticklabels=row_labels_plot,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"label": "similarity"},
        ax=ax,
    )
    ax.set_title(f"Mapping of {gene_id} clusters to bulk prototypes")
    ax.set_xlabel("bulk consensus ID")
    ax.set_ylabel("gene cluster ID")

    # Overlay assignment markers: CID -> row index in the plotted order
    row_index_map = {cid: i for i, cid in enumerate(row_labels_plot)}
    for _, r in ass_sub.iterrows():
        cid = int(r["cluster_id"])
        cons_id = r["consensus_id"]
        # unassigned (None / NaN) or unknown prototypes get no marker
        if cons_id is None or cons_id not in col_labels:
            continue
        i = row_index_map.get(cid, None)
        if i is None:
            continue
        j = col_labels.index(cons_id)
        # place marker at the center of the heatmap cell
        ax.scatter(j + 0.5, i + 0.5, s=60, facecolors="none", edgecolors="white", linewidths=1.5)

    plt.tight_layout()
    return fig, ax


def plot_assignment_summary(assign_df: pd.DataFrame):
    """
    Bar chart of how many gene clusters (across all genes) were assigned to each
    bulk prototype, largest first.

    Rows whose 'consensus_id' is missing (None/NaN, i.e. unassigned) are ignored.
    Returns the (fig, ax) pair.
    """
    assigned_only = assign_df.dropna(subset=["consensus_id"])
    per_prototype = assigned_only.groupby("consensus_id").size()
    per_prototype = per_prototype.sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(8, 3))
    per_prototype.plot(kind="bar", ax=ax, color="#4c72b0")

    ax.set_title("Assignment summary across genes")
    ax.set_xlabel("bulk consensus ID")
    ax.set_ylabel("# gene clusters assigned")

    plt.tight_layout()
    return fig, ax


def recolor_meta_by_bulk(assign_df, meta_df, bulk_color_map, gene_id):
    """
    Return a copy of ``meta_df`` whose 'cluster_color' follows the bulk prototypes
    that the clusters of ``gene_id`` were assigned to.

    Parameters
    ----------
    assign_df : pd.DataFrame
        Output of ``assign_gene_clusters_to_consensus`` (needs 'gene_tss',
        'cluster_id', 'consensus_id').
    meta_df : pd.DataFrame
        Cluster metadata with 'cluster_id' (and optionally 'gene_tss' and
        'cluster_color').
    bulk_color_map : dict
        consensus_id -> color.
    gene_id : str
        Gene whose assignment is used.

    Returns
    -------
    pd.DataFrame
        Copy of ``meta_df`` with the same rows in the same order. For rows of
        ``gene_id`` (every row when ``meta_df`` has no 'gene_tss' column) the
        color becomes ``bulk_color_map[consensus_id]``. The following rows keep
        their original 'cluster_color' (None if that column is missing):
        unassigned rows, rows whose prototype has no color, and rows of other genes.
    """
    m = meta_df.copy()
    ass_sub = assign_df[assign_df["gene_tss"] == gene_id]
    ass_map = {int(cid): cons for cid, cons in zip(ass_sub["cluster_id"], ass_sub["consensus_id"])}

    # cluster_id is only unique within a gene: never apply this gene's mapping to other genes' rows
    if "gene_tss" in m.columns:
        in_gene = (m["gene_tss"] == gene_id).to_numpy()
    else:
        in_gene = np.ones(len(m), dtype=bool)

    colors = []
    for belongs, (_, r) in zip(in_gene, m.iterrows()):
        original = r.get("cluster_color", None)
        if belongs:
            cons = ass_map.get(int(r["cluster_id"]), None)
            colors.append(bulk_color_map.get(cons, original))
        else:
            colors.append(original)
    return m.assign(cluster_color=colors)


def plot_assignment_confidence(assign_df: pd.DataFrame, gene_id: str | None = None):
    """
    Scatter plot of the assigned similarity vs. the assignment margin.

    The 'margin' column produced by ``assign_gene_clusters_to_consensus`` is
    ``second_best_cost - best_cost``. This equals
    ``similarity - second_best_similarity`` (best minus second-best).
    - A small margin means an ambiguous assignment.
    - A negative margin means the global Hungarian solution gave this cluster
      a worse option than its best one.

    Parameters
    ----------
    assign_df : pd.DataFrame
        Needs 'similarity', 'margin' and (when ``gene_id`` is given) 'gene_tss'.
    gene_id : str, optional
        Restrict the plot to one gene.

    Returns
    -------
    (fig, ax) of the created figure.
    """
    df = assign_df.copy()
    if gene_id is not None:
        df = df[df["gene_tss"] == gene_id]
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(df["similarity"], df["margin"], s=30, alpha=0.7)
    ax.axhline(0.0, color="grey", linewidth=0.8, linestyle="--")  # below: a better option existed
    ax.set_xlabel("best similarity")
    ax.set_ylabel("margin (best − second-best)")
    ax.set_title("Assignment confidence" + (f" ({gene_id})" if gene_id else ""))
    ax.grid(alpha=0.3)
    plt.tight_layout()
    return fig, ax


##### Running code on the actual dataset to preprocess + plot the reads

In [None]:
#### BINNING THE BULK DATA TO BE ABLE TO DO REASSIGNMENT
# Bin the bulk reads (50 bp bins) and build, per bulk cluster:
#   - profiles_df_bulk_binned : mean methylation per bin (cluster x C_pos)
#   - coverage_df_bulk_binned : the cluster's coverage, repeated in every bin that has data
#   - meta_df_bulk_binned     : cluster_id, coverage and % of the total coverage
#   - positions_bulk_binned   : sorted bin positions (column labels of the two matrices)

d = bin_then_matrix(
    df_bulk,
    indexes=['cluster'],
    bin_size=50
)

# ---- Mean methylation per (cluster, bin) ----
meth_agg = (
    d.groupby(['cluster', 'bin'], as_index=False)
    .agg(meth_mean=('meth', 'mean'))
)

# Take unique cov per cluster (assumes cov constant per cluster) and broadcast it
# onto every (cluster, bin) pair that has a methylation value.
cov_cluster = d.groupby('cluster')['cov'].first()
agg = meth_agg.assign(cov_bin=meth_agg['cluster'].map(cov_cluster).astype(float))

# ---- Pivot to wide profiles ----
profiles_df_bulk_binned = agg.pivot(index='cluster', columns='bin', values='meth_mean')
coverage_df_bulk_binned = agg.pivot(index='cluster', columns='bin', values='cov_bin')

# ---- Sort columns and set names ----
positions_bulk_binned = np.array(sorted(profiles_df_bulk_binned.columns.astype(int)), dtype=int)
profiles_df_bulk_binned = profiles_df_bulk_binned.reindex(columns=positions_bulk_binned)
coverage_df_bulk_binned = coverage_df_bulk_binned.reindex(columns=positions_bulk_binned)

profiles_df_bulk_binned.index.name = 'cluster'
profiles_df_bulk_binned.columns = pd.Index(positions_bulk_binned, name='C_pos')
coverage_df_bulk_binned.index.name = 'cluster'
coverage_df_bulk_binned.columns = pd.Index(positions_bulk_binned, name='C_pos')

# ---- Per-cluster metadata: coverage and share of the total coverage (in %) ----
n_metric = cov_cluster.astype(float)
total_cov = float(cov_cluster.sum())
proportions = (cov_cluster / (total_cov if total_cov > 0 else 1.0)) * 100.0

meta_rows = []
for cid in profiles_df_bulk_binned.index:
    meta_rows.append({
        'cluster_id': cid if isinstance(cid, (int, np.integer)) else str(cid),
        'cov_total': float(n_metric.get(cid, np.nan)),
        'proportion': float(proportions.get(cid, 0.0))
    })
meta_df_bulk_binned = pd.DataFrame(meta_rows).sort_values('cluster_id').reset_index(drop=True)

assert profiles_df_bulk_binned.index.equals(coverage_df_bulk_binned.index)

# Only the last expression of a cell is displayed
meta_df_bulk_binned


In [None]:
## Preprocess the raw dataframe, dropping the columns not needed downstream
columns_to_drop = [
    'read_strand', 'N', 'n_meth_motif', 'perc_meth_in_motif',
    'peak_cov', 'peak_meth', 'C_in_motif', 'n',
    'read_meth_mean', 'motif_id', 'gene2', 'bound',
    'cov', 'peakid', 'meth_C_in_motif', 'read_meth_C_in_motif',
]

df_proc = preprocess_dataframe(df, columns_to_drop)

df_proc.head()

In [None]:
# Whole preprocessed dataframe as a binned (50 bp) methylation matrix.
# Used below to select the genes worth clustering.
df_whole_matrix = process_into_matrix(
    df_proc,
    gene=None,
    bin_size=50,
)
df_whole_matrix.head()

In [None]:
# Whole preprocessed dataframe as an unbinned (1 bp) methylation matrix.
# Needed for the unbinned average methylation profile.
df_whole_matrix_unbinned = process_into_matrix(
    df_proc,
    gene=None,
    bin_size=1,
)
df_whole_matrix_unbinned.head(50)

In [None]:
# Keep only the genes with enough overlapping reads and a long enough covered span
df_filtered, filtered_regions = filter_reads_per_gene_middle_bin_name(
    df_whole_matrix,
    middle_bin_name=None,
    min_reads=50,
    min_bins=40,
    require_middle_bin=True,
)

print(f'THe shape of the filtered dataframe is of {df_filtered.shape}')
print(f'There is {df_filtered.index.get_level_values("gene_tss").nunique()} different genes in this dataframe')

df_filtered.head()

In [None]:
# Split the filtered (binned) and the unbinned matrices into per-gene pieces.
# Each call returns (gene names, per-gene sub-dataframes); the two are presumably
# parallel lists (later cells index both with the same position).
# NOTE: the unbinned matrix is NOT filtered, so its gene order differs from gene_names.
gene_names, gene_df = get_genes_list(df_filtered)
gene_names_unbinned, gene_df_unbinned = get_genes_list(df_whole_matrix_unbinned)

In [None]:
# Inspect the filtered gene names and their per-gene sub-dataframes
print(gene_names)
print(gene_df)

In [None]:
# Select the gene to cluster and a second test gene from the filtered gene list
# (e.g. index 19 -> SETD1A_30958295_30958295).
GENE_IDX = 7
TEST_GENE_IDX = 19

gene = gene_names[GENE_IDX]
df_gene = gene_df[GENE_IDX]

gene_test = gene_names[TEST_GENE_IDX]
df_gene_test = gene_df[TEST_GENE_IDX]

# gene_names_unbinned comes from the *unfiltered* matrix, so its positions do not
# line up with gene_names: look the test gene up by name rather than by index.
df_gene_unbinned = gene_df_unbinned[list(gene_names_unbinned).index(gene_test)]

print(gene)
print(gene_test)

In [None]:
# Genomic span covered by the CDC45 reads: smallest C_start / largest C_end
# over all of the gene's rows (not only the first few shown by .head()).
a = df_proc[df_proc['gene_tss'] == 'CDC45_19505053_19505053']
print(a['C_start'].min())
print(a['C_end'].max())

In [None]:
# Genomic span covered by the SETD1A reads: smallest C_start / largest C_end
# over all of the gene's rows (not only the first few shown by .head()).
a = df_proc[df_proc['gene_tss'] == 'SETD1A_30958295_30958295']
print(a['C_start'].min())
print(a['C_end'].max())

In [None]:
# Plot every read of the selected gene: whole gene, no colouring, no faceting
df_to_plot = preprocess_long_for_plot(df_proc)
plot_reads_long(
    df_to_plot,
    filters={'gene_tss': gene},
    facet_by=None,
    color_by=None,
    hex_colors=None,
)

df_to_plot.head(50)

In [None]:
# Plot every read of the test gene: whole gene, no colouring, no faceting
plot_reads_long(
    df_to_plot,
    filters={'gene_tss': gene_test},
    facet_by=None,
    color_by=None,
    hex_colors=None,
)

In [None]:
# Same reads of the selected gene, now with one facet per treatment group (no colouring)
plot_reads_long(
    df_to_plot,
    filters={'gene_tss': gene},
    facet_by='group',
    color_by=None,
    hex_colors=None,
)

##### Running the clustering on the actual data

In [None]:
# ADAPTATIVE SWEEP (unused)
# summary_best, best = sweep_k_for_target_fast(
#     df=df_gene,                # reads × bins matrix
#     k_list=[5, 10, 15, 20, 25, 30],
#     target=6,
#     transform='arcsine',       # or 'logit'
#     metric='euclidean',
#     scaling=False,
#     pca_or_not=True,
#     n_pcs=None,                # None => 95% variance retained
#     kernel_type='laplacian',
#     nan_threshold=0.7,
#     nan_method='drop',
#     res_init=1.0,
#     res_min=0.05,
#     res_max=2.0,
#     up_factor=1.6,
#     max_iter=10,
#     silhouette_sample=None,    # e.g., 2000 to speed up on large N
#     seed=42
# )

# # --- Use best solution ---
# clusters = best['clusters']
# chosen_k = best['k']
# chosen_res = best['resolution']
# X_pca = best['X_pca']
# g = best['graph']

# print(f"Chosen k/res: {chosen_k} / {chosen_res:.3f} | n_clusters: {best['n_clusters']}")


In [None]:
## testing if there are NaN in the distance matrices for different metrics
# for metric in ['cosine', 'correlation', 'euclidean-logit', 'jsd']:
#     D = pairwise_distance_profiles(X_new, metric)
#     print(metric, np.isnan(D).any())

In [None]:
# ## COMPARING THE DIFFERENT DISTANCE METRICS FOR THE LEIDEN CLUSTERING ON A SAME PLOT
# metrics_to_compare = [
#     'cosine',
#     'correlation',
#     'euclidean-logit',
#     'jsd'
#     ]

# plot_leiden_umap_comparison(
#     X= X_new,
#     hex_colors=hex_colors,
#     metrics=metrics_to_compare,
#     n_neighbors=15,
#     resolution=0.8,
# )


In [None]:
# ## PLOTING MULTIPLE UMAPS FOR ALL THE GENES IN THE DF_FILTERED
# plot_multiple_umaps(df_filtered,gene_names,leiden_neighbors= 10, umap_neighbors=15, resolution= 0.8, min_dist=0.1, leiden_metric = 'cosine', umap_metric="cosine")

In [None]:
# #df_clean is the df without the reads that have a missing value
# df_clean, X, partition, clusters = clustering(df_gene, n_neighbors=10, knn_metric='euclidean', leiden_resolution=0.8)
# df_clean_2, X_2, partition_2, clusters_2 = clustering_improved(df_gene, n_neighbors=15, 
#                                                             #    n_pcs=5,
#                                                                metric='euclidean', transform='logit',leiden_resolution=0.8, seed=42)

# plotting_the_clustering(
#     X, 
#     clusters, 
#     n_neighbors=15, 
#     min_dist=0.1, 
#     metric='euclidean')

# plotting_the_clustering_2(
#     X_2,                 # raw profiles or PCA scores
#     clusters_2,          # Leiden cluster labels, shape (n_samples,)
#     n_neighbors=15,
#     min_dist=0.1,
#     metric='euclidean',  # 'euclidean', 'cosine', 'correlation', 'jaccard', etc.
#     transform=None,      # None | 'logit' | 'sqrt' (apply if X are raw proportions)
#     n_pcs=None,          # e.g., 30 if you want PCA before UMAP; None = use X as-is
#     seed=42
# )

# embedding, fig = plot_umap(X, clusters, n_neighbors=15, min_dist=0.1, metric='euclidean', n_pcs=None, standardize=False, seed=42, palette= hex_colors, title="UMAP - clustering with euclidean metric (impoved)")
# embedding_2, fig_2 = plot_umap(X_2, clusters_2, n_neighbors=15, min_dist=0.1, metric='euclidean', n_pcs=None, standardize=False, seed=42, title="UMAP - clustering with euclidean metric (impoved)")

In [None]:
# df_PCA_euclidean, X_PCA_euclidean, partition_PCA_euclidean, clusters_PCA_euclidean,  = alternate_clustering(
#     df_gene_test,
#     apply_filter= True,
#     min_bin_non_nan_frac = 0.5,
#     min_read_non_nan_frac = 0.5,
#     transform = 'logit',  # None | 'arcsine' | 'logit'
#     min_overlap = 40,
#     k = 10,
#     positive_only = True,
#     shrink_c = None,
#     mutual=False,
#     leiden_resolution= 1,
#     seed= 42
# )

# n_clusters = len(set(clusters_PCA_euclidean)) 
# print("clusters:", n_clusters)

# # embedding_PCA_euclidean, fig_PCA_euclidean = plot_umap(
# #     X_PCA_euclidean,
# #     clusters_PCA_euclidean,
# #     n_neighbors=25,
# #     min_dist=0.1,
# #     metric='euclidean',      # for Pipeline 2, use 'euclidean'
# #     transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
# #     n_pcs=None,              # set if X are raw features; None if X are already PCs
# #     standardize=False,       # True if using raw features without PCA
# #     seed=42,
# #     palette=hex_colors,            # optional list/array or matplotlib colormap name
# #     title="PCA + Euclidean (improved)"
# # )

# read_dict_PCA_euclidean = dict_id_cluster_color(df_PCA_euclidean, clusters_PCA_euclidean, hex_colors)

# df_dict_PCA_euclidean = dict_to_df(read_dict_PCA_euclidean)
# df_test_PCA_euclidean= merge(df_proc, df_dict_PCA_euclidean)

# # Long format for plotting
# df_test_plot_PCA_euclidean = preprocess_long_for_plot(df_test_PCA_euclidean, include_locus_cluster=True)

# # Plot
# plot_reads_long(
#     df_test_plot_PCA_euclidean,
#     filters={
#         # "gene_tss": gene, 
#         'group':['BT474_mV_high_Untreated',
#         'BT474_mV_low_Untreated',
#         'BT474_mV_Untreated_Unsorted']},
#     # facet_by='group',
#     color_by="locus_cluster",
#     hex_colors=hex_colors
# )

# df_only_interesting_groups_PCA_euclidean = df_PCA_euclidean[df_PCA_euclidean.index.get_level_values('group').isin([
#     'BT474_mV_high_Untreated',
#     'BT474_mV_low_Untreated',
#     'BT474_mV_Untreated_Unsorted'
# ])]

# df_only_interesting_groups_PCA_euclidean = merge(df_only_interesting_groups_PCA_euclidean, df_dict_PCA_euclidean)
# clusters_of_interest_PCA_euclidean = df_only_interesting_groups_PCA_euclidean['locus_cluster'].tolist() 
# summary(df_only_interesting_groups_PCA_euclidean, clusters_of_interest_PCA_euclidean)


# start_PCA_euclidean, end_PCA_euclidean, center_coord_PCA_euclidean= start_end_center(df_test_plot_PCA_euclidean)

# plot_avg_methylation_profile(
#     df= df_PCA_euclidean,
#     df_dict=df_dict_PCA_euclidean,
#     start=start_PCA_euclidean,
#     end=end_PCA_euclidean,
#     center_coord=center_coord_PCA_euclidean,
#     read_dict=read_dict_PCA_euclidean
# )

In [None]:
# df_PCA_euclidean, X_PCA_euclidean, partition_PCA_euclidean, clusters_PCA_euclidean, metrics_PCA_euclidean = clustering_final(
#     df_gene,
#     n_neighbors=15,
#     nan_threshold=0.7,
#     nan_method='drop',
#     scaling=False,
#     pca_or_not=True,
#     n_pcs=30,           # None = no PCA; int for #PCs; float in (0,1) for variance ratio
#     metric='euclidean',          # 'euclidean' or 'cosine' 
#     transform='logit',            # 'none', 'logit', or 'arcsine'
#     kernel_type='laplacian',
#     leiden_resolution=0.8,
#     seed=42
# )

# embedding_PCA_euclidean, fig_PCA_euclidean = plot_umap(
#     X_PCA_euclidean,
#     clusters_PCA_euclidean,
#     n_neighbors=25,
#     min_dist=0.1,
#     metric='euclidean',      # for Pipeline 2, use 'euclidean'
#     transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
#     n_pcs=None,              # set if X are raw features; None if X are already PCs
#     standardize=False,       # True if using raw features without PCA
#     seed=42,
#     palette=hex_colors,            # optional list/array or matplotlib colormap name
#     title="PCA + Euclidean (improved)"
# )

# print(metrics_PCA_euclidean)

# multi_sil_PCA_euclidean= silhouettes_multi(X_PCA_euclidean, clusters_PCA_euclidean, metric_main='euclidean', extra_metrics=['cosine', 'correlation', 'jaccard'])

# print(f"Silhouette (PCA + Euclidean): {metrics_PCA_euclidean['silhouette']:.3f}")
# print(f"Silhouette (PCA + Euclidean) - Cosine: {multi_sil_PCA_euclidean['silhouette_cosine']:.3f}")
# print(f"Silhouette (PCA + Euclidean) - Correlation: {multi_sil_PCA_euclidean['silhouette_correlation']:.3f}")
# print(f"Silhouette (PCA + Euclidean) - Jaccard: {multi_sil_PCA_euclidean['silhouette_jaccard']:.3f}")


# read_dict_PCA_euclidean = dict_id_cluster_color(df_PCA_euclidean, clusters_PCA_euclidean, hex_colors)

# df_dict_PCA_euclidean = dict_to_df(read_dict_PCA_euclidean)
# df_test_PCA_euclidean= merge(df_proc, df_dict_PCA_euclidean)

# # Long format for plotting
# df_test_plot_PCA_euclidean = preprocess_long_for_plot(df_test_PCA_euclidean, include_locus_cluster=True)

# # Plot
# plot_reads_long(
#     df_test_plot_PCA_euclidean,
#     filters={
#         # "gene_tss": gene, 
#         'group':['BT474_mV_high_Untreated',
#         'BT474_mV_low_Untreated',
#         'BT474_mV_Untreated_Unsorted']},
#     # facet_by='group',
#     color_by="locus_cluster",
#     hex_colors=hex_colors
# )

# df_only_interesting_groups_PCA_euclidean = df_PCA_euclidean[df_PCA_euclidean.index.get_level_values('group').isin([
#     'BT474_mV_high_Untreated',
#     'BT474_mV_low_Untreated',
#     'BT474_mV_Untreated_Unsorted'
# ])]

# df_only_interesting_groups_PCA_euclidean = merge(df_only_interesting_groups_PCA_euclidean, df_dict_PCA_euclidean)
# clusters_of_interest_PCA_euclidean = df_only_interesting_groups_PCA_euclidean['locus_cluster'].tolist() 
# summary(df_only_interesting_groups_PCA_euclidean, clusters_of_interest_PCA_euclidean)


# start_PCA_euclidean, end_PCA_euclidean, center_coord_PCA_euclidean= start_end_center(df_test_plot_PCA_euclidean)

# plot_avg_methylation_profile(
#     df= df_PCA_euclidean,
#     df_dict=df_dict_PCA_euclidean,
#     start=start_PCA_euclidean,
#     end=end_PCA_euclidean,
#     center_coord=center_coord_PCA_euclidean,
#     read_dict=read_dict_PCA_euclidean
# )

In [None]:
# ---- Leiden clustering of the reads of `gene` (arcsine transform, PCA, euclidean metric, laplacian kernel) ----
df_PCA_euclidean, X_PCA_euclidean, partition_PCA_euclidean, clusters_PCA_euclidean, metrics_PCA_euclidean = clustering_final(
    df_gene,
    n_neighbors=15,
    nan_threshold=0.7,
    nan_method='drop',
    scaling=False,
    pca_or_not=True,
    n_pcs=None,                 # NOTE(review): the sweep cell documents "None => 95% variance retained" when PCA is on; confirm in clustering_final
    metric='euclidean',         # 'euclidean' or 'cosine'
    transform='arcsine',        # 'none', 'logit', or 'arcsine'
    kernel_type='laplacian',
    leiden_resolution=0.8,
    seed=42,
)

print(clusters_PCA_euclidean)

embedding_PCA_euclidean, fig_PCA_euclidean = plot_umap(
    X_PCA_euclidean,
    clusters_PCA_euclidean,
    n_neighbors=25,
    min_dist=0.1,
    metric='euclidean',         # for Pipeline 2, use 'euclidean'
    transform=None,             # None | 'logit' | 'arcsine' (use only if X are raw proportions)
    n_pcs=None,                 # set if X are raw features; None if X are already PCs
    standardize=False,          # True if using raw features without PCA
    seed=42,
    palette=hex_colors,         # optional list/array or matplotlib colormap name
    title="PCA + euclidean + laplacian",
)

# ---- Per-read cluster id + colour, merged back onto the preprocessed reads ----
read_dict_PCA_euclidean = dict_id_cluster_color(df_PCA_euclidean, clusters_PCA_euclidean, hex_colors)
df_dict_PCA_euclidean = dict_to_df(read_dict_PCA_euclidean)
df_test_PCA_euclidean = merge(df_proc, df_dict_PCA_euclidean)

# Treatment groups shown in the read plot and in the cluster summary
groups_of_interest = [
    'BT474_mV_high_Untreated',
    'BT474_mV_low_Untreated',
    'BT474_mV_Untreated_Unsorted',
]

# ---- Long format for plotting ----
df_test_plot_PCA_euclidean = preprocess_long_for_plot(
    df_test_PCA_euclidean,
    include_locus_cluster=True,
    filter_outliers=True,
    max_span_bp=None,           # let the filtering be driven by span_quantile
    span_quantile=0.99,
    require_center_inside=True,
    min_cpg=1,
    cpg_window_bp=5000,
)

# ---- Plot the reads coloured by cluster ----
plot_reads_long(
    df_test_plot_PCA_euclidean,
    filters={'group': groups_of_interest},
    color_by="locus_cluster",
    gene=gene,
    hex_colors=hex_colors,
)

# ---- Cluster composition restricted to the groups of interest ----
df_only_interesting_groups_PCA_euclidean = df_PCA_euclidean[
    df_PCA_euclidean.index.get_level_values('group').isin(groups_of_interest)
]
df_only_interesting_groups_PCA_euclidean = merge(df_only_interesting_groups_PCA_euclidean, df_dict_PCA_euclidean)
clusters_of_interest_PCA_euclidean = df_only_interesting_groups_PCA_euclidean['locus_cluster'].tolist()
summary(df_only_interesting_groups_PCA_euclidean, clusters_of_interest_PCA_euclidean, gene=gene)


# start_PCA_euclidean, end_PCA_euclidean, center_coord_PCA_euclidean= start_end_center(df_test_plot_PCA_euclidean)

# ---- Per-cluster average methylation profiles ----
profiles_df_PCA_euclidean, coverage_df_PCA_euclidean, meta_df_PCA_euclidean, positions_PCA_euclidean = compute_gene_centroids(
    df_PCA_euclidean,
    df_dict_PCA_euclidean,
    gene_tss=gene,
)

plot_centroids_with_shading(
    profiles_df_PCA_euclidean,
    positions_PCA_euclidean,
    meta_df_PCA_euclidean,
    coverage_df_PCA_euclidean,
    hex_colors,
    smooth_sigma=0,
    title=f"{gene} - Average DNA Methylation Profiles per Cluster",
    missingness_threshold=0.3,
)

# position_cols, meth_arrays = plot_avg_methylation_profile(
#         df= df_PCA_euclidean,
#         df_dict=df_dict_PCA_euclidean,
#         start=start_PCA_euclidean,
#         end=end_PCA_euclidean,
#         center_coord=center_coord_PCA_euclidean,
#         read_dict=read_dict_PCA_euclidean,
#         gene= gene
#     )

In [None]:
# PCA + Euclidean clustering of the current gene (arcsine transform, Laplacian kernel).
df_PCA_euclidean, X_PCA_euclidean, partition_PCA_euclidean, clusters_PCA_euclidean, metrics_PCA_euclidean = clustering_final(
        df_gene,
        n_neighbors=15,
        nan_threshold=0.7,
        nan_method='drop',
        scaling=False,
        pca_or_not=True,
        n_pcs=None,           # None = no PCA; int for #PCs; float in (0,1) for variance ratio
        metric='euclidean',          # 'euclidean' or 'cosine'
        transform='arcsine',            # 'none', 'logit', or 'arcsine'
        kernel_type='laplacian',
        leiden_resolution=0.8,
        seed=42
    )

print(clusters_PCA_euclidean)

embedding_PCA_euclidean, fig_PCA_euclidean = plot_umap(
        X_PCA_euclidean,
        clusters_PCA_euclidean,
        n_neighbors=25,
        min_dist=0.1,
        metric='euclidean',      # for Pipeline 2, use 'euclidean'
        transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
        n_pcs=None,              # set if X are raw features; None if X are already PCs
        standardize=False,       # True if using raw features without PCA
        seed=42,
        palette=hex_colors,            # optional list/array or matplotlib colormap name
        title="PCA + euclidean + laplacian"
    )

# A bare expression in the middle of a cell is evaluated and discarded by
# Jupyter; display() makes the preview actually appear.
display(df_PCA_euclidean.head())

# read id -> (cluster, colour) mapping built from the clustered frame.
read_dict_PCA_euclidean = dict_id_cluster_color(df_PCA_euclidean, clusters_PCA_euclidean, hex_colors)

df_dict_PCA_euclidean = dict_to_df(read_dict_PCA_euclidean)
df_test_PCA_euclidean = merge(df_proc, df_dict_PCA_euclidean)

# Long format for plotting
df_test_plot_PCA_euclidean = preprocess_long_for_plot(df_test_PCA_euclidean,
                                                      include_locus_cluster=True,
                                                      filter_outliers=True,
                                                      max_span_bp=None,     # filtering is based on span_quantile instead
                                                      span_quantile=0.99,
                                                      require_center_inside=True,
                                                      min_cpg=1,
                                                      cpg_window_bp=5000
                                                      )

# Plot
plot_reads_long(
        df_test_plot_PCA_euclidean,
        filters={
            # "gene_tss": gene,
            'group': ['BT474_mV_high_Untreated',
                      'BT474_mV_low_Untreated',
                      'BT474_mV_Untreated_Unsorted']},
        # facet_by='group',
        color_by="locus_cluster",
        gene=gene,
        hex_colors=hex_colors
    )

# Restrict to the three untreated groups (by the 'group' index level) and summarise.
df_only_interesting_groups_PCA_euclidean = df_PCA_euclidean[df_PCA_euclidean.index.get_level_values('group').isin([
        'BT474_mV_high_Untreated',
        'BT474_mV_low_Untreated',
        'BT474_mV_Untreated_Unsorted'
    ])]

df_only_interesting_groups_PCA_euclidean = merge(df_only_interesting_groups_PCA_euclidean, df_dict_PCA_euclidean)
clusters_of_interest_PCA_euclidean = df_only_interesting_groups_PCA_euclidean['locus_cluster'].tolist()
summary(df_only_interesting_groups_PCA_euclidean, clusters_of_interest_PCA_euclidean, gene=gene)


start_PCA_euclidean, end_PCA_euclidean, center_coord_PCA_euclidean = start_end_center(df_test_plot_PCA_euclidean)

# profiles_df_PCA_euclidean, coverage_df_PCA_euclidean, meta_df_PCA_euclidean, positions_PCA_euclidean= compute_gene_centroids(
#                                                                     df_PCA_euclidean,
#                                                                     df_dict_PCA_euclidean,
#                                                                     gene_tss=gene,
#                                                                    )

# plot_centroids_with_shading(
#     profiles_df_PCA_euclidean,
#     positions_PCA_euclidean,
#     meta_df_PCA_euclidean,
#     coverage_df_PCA_euclidean,
#     hex_colors,
#     smooth_sigma = 0,
#     title=gene + 'Average DNA Methylation Profiles per Cluster',
#     missingness_threshold=0.3,
# )

position_cols, meth_arrays = plot_avg_methylation_profile(
        df=df_PCA_euclidean,
        df_dict=df_dict_PCA_euclidean,
        start=start_PCA_euclidean,
        end=end_PCA_euclidean,
        center_coord=center_coord_PCA_euclidean,
        read_dict=read_dict_PCA_euclidean,
        gene=gene
    )

In [None]:
## TESTING OUT THE CLUSTERING IMPROVED OLD ON PCA AND EUCLIDEAN 

# df_PCA_euclidean, X_PCA_euclidean, partition_PCA_euclidean, clusters_PCA_euclidean, = clustering_improved_old(
#     df_gene,
#     n_neighbors=15,
#     nan_threshold=0.7,
#     nan_method='drop',
#     scaling=False,
#     n_pcs=30,           # None = no PCA; int for #PCs; float in (0,1) for variance ratio
#     metric='euclidean',          # 'euclidean' or 'cosine' 
#     transform='logit',            # 'none', 'logit', or 'arcsine'
#     leiden_resolution=0.8,
#     seed=42
# )

# embedding_PCA_euclidean, fig_PCA_euclidean = plot_umap(
#     X_PCA_euclidean,
#     clusters_PCA_euclidean,
#     n_neighbors=25,
#     min_dist=0.1,
#     metric='euclidean',      # for Pipeline 2, use 'euclidean'
#     transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
#     n_pcs=None,              # set if X are raw features; None if X are already PCs
#     standardize=False,       # True if using raw features without PCA
#     seed=42,
#     palette=hex_colors,            # optional list/array or matplotlib colormap name
#     title="PCA + Euclidean (improved)"
# )

# multi_sil_PCA_euclidean= silhouettes_multi(X_PCA_euclidean, clusters_PCA_euclidean, metric_main='euclidean', extra_metrics=['cosine', 'correlation', 'jaccard'])

# print(f"Silhouette (PCA + Euclidean): {metrics_PCA_euclidean['silhouette']:.3f}")
# print(f"Silhouette (PCA + Euclidean) - Cosine: {multi_sil_PCA_euclidean['silhouette_cosine']:.3f}")
# print(f"Silhouette (PCA + Euclidean) - Correlation: {multi_sil_PCA_euclidean['silhouette_correlation']:.3f}")
# print(f"Silhouette (PCA + Euclidean) - Jaccard: {multi_sil_PCA_euclidean['silhouette_jaccard']:.3f}")


# read_dict_PCA_euclidean = dict_id_cluster_color(df_PCA_euclidean, clusters_PCA_euclidean, hex_colors)

# df_dict_PCA_euclidean = dict_to_df(read_dict_PCA_euclidean)
# df_test_PCA_euclidean= merge(df_proc, df_dict_PCA_euclidean)

# # Long format for plotting
# df_test_plot_PCA_euclidean = preprocess_long_for_plot(df_test_PCA_euclidean, include_locus_cluster=True)

# # Plot
# plot_reads_long(
#     df_test_plot_PCA_euclidean,
#     filters={
#         # "gene_tss": gene, 
#         'group':['BT474_mV_high_Untreated',
#         'BT474_mV_low_Untreated',
#         'BT474_mV_Untreated_Unsorted']},
#     # facet_by='group',
#     color_by="locus_cluster",
#     hex_colors=hex_colors
# )

# df_only_interesting_groups_PCA_euclidean = df_PCA_euclidean[df_PCA_euclidean.index.get_level_values('group').isin([
#     'BT474_mV_high_Untreated',
#     'BT474_mV_low_Untreated',
#     'BT474_mV_Untreated_Unsorted'
# ])]

# df_only_interesting_groups_PCA_euclidean = merge(df_only_interesting_groups_PCA_euclidean, df_dict_PCA_euclidean)
# clusters_of_interest_PCA_euclidean = df_only_interesting_groups_PCA_euclidean['locus_cluster'].tolist() 
# summary(df_only_interesting_groups_PCA_euclidean, clusters_of_interest_PCA_euclidean)


# start_PCA_euclidean, end_PCA_euclidean, center_coord_PCA_euclidean= start_end_center(df_test_plot_PCA_euclidean)

# plot_avg_methylation_profile(
#     df= df_PCA_euclidean,
#     df_dict=df_dict_PCA_euclidean,
#     start=start_PCA_euclidean,
#     end=end_PCA_euclidean,
#     center_coord=center_coord_PCA_euclidean,
#     read_dict=read_dict_PCA_euclidean
# )

In [None]:
# Variant run: stricter NaN filter (nan_threshold=0.1) and fewer neighbours (9).
# Every output gets an `_unbinned` suffix so the PCA + Euclidean results of the
# previous cells (df_PCA_euclidean, ...) are not silently overwritten.
# NOTE(review): this clusters df_gene; if the intent was to cluster the unbinned
# matrix, pass df_gene_unbinned here instead — confirm.
df_PCA_euclidean_unbinned, X_PCA_euclidean_unbinned, partition_PCA_euclidean_unbinned, clusters_PCA_euclidean_unbinned, metrics_PCA_euclidean_unbinned = clustering_final(
        df_gene,
        n_neighbors=9,
        nan_threshold=0.1,
        nan_method='drop',
        scaling=False,
        pca_or_not=True,
        n_pcs=None,           # None = no PCA; int for #PCs; float in (0,1) for variance ratio
        metric='euclidean',          # 'euclidean' or 'cosine'
        transform='arcsine',            # 'none', 'logit', or 'arcsine'
        kernel_type='laplacian',
        leiden_resolution=0.8,
        seed=42
    )

# Build the read id -> (cluster, colour) mapping from the frame that was actually
# clustered, as every other cell does. With nan_method='drop' that frame can differ
# from any other matrix, so pairing the labels with another frame could misalign them.
read_dict_PCA_euclidean_unbinned = dict_id_cluster_color(df_PCA_euclidean_unbinned, clusters_PCA_euclidean_unbinned, hex_colors)

df_dict_PCA_euclidean_unbinned = dict_to_df(read_dict_PCA_euclidean_unbinned)
df_test_PCA_euclidean_unbinned = merge(df_proc, df_dict_PCA_euclidean_unbinned)

df_test_plot_PCA_euclidean_unbinned = preprocess_long_for_plot(df_test_PCA_euclidean_unbinned,
                                                               include_locus_cluster=True,
                                                               filter_outliers=True,
                                                               max_span_bp=None,     # filtering is based on span_quantile instead
                                                               span_quantile=0.99,
                                                               require_center_inside=True,
                                                               min_cpg=1,
                                                               cpg_window_bp=5000
                                                               )

# Centroids are computed on the unbinned matrix, using the cluster mapping above.
profiles_df_unbinned, coverage_df_unbinned_euclidean, meta_df_unbinned_euclidean, positions_unbinned_euclidean = compute_gene_centroids(
    df_gene_unbinned,
    df_dict_PCA_euclidean_unbinned,
    # NOTE(review): hard-coded locus while the title below uses `gene` — confirm they match.
    gene_tss='SETD1A_30958295_30958295',
)

plot_centroids_with_shading(
    profiles_df_unbinned,
    positions_unbinned_euclidean,
    meta_df_unbinned_euclidean,
    coverage_df_unbinned_euclidean,
    hex_colors,
    smooth_sigma=0,
    title=gene + ': Average DNA Methylation Profiles per Cluster',
    missingness_threshold=0.5,
)


In [None]:
# Alternative correlation-graph clustering of df_gene. It returns, in order,
# the cluster labels, a graph object and the frame that was clustered.
clusters_corr, g_corr, df_corr = alternate_clustering(
    df_gene,
    seed=42,
    # --- filtering / transform ---
    apply_filter=False,
    min_bin_non_nan_frac=0.7,
    min_read_non_nan_frac=1,
    transform=None,  # None | 'arcsine' | 'logit'
    # --- neighbour graph ---
    k=40,
    mutual=True,
    min_overlap=40,
    positive_only=True,
    shrink_c=0.0,
    # --- community detection ---
    leiden_resolution=0.6,
)

# NOTE(review): the commented-out UMAP / silhouette code below refers to the
# cosine pipeline's variables (X_cosine, clusters_cosine, metrics_cosine), not to
# this correlation run.
# embedding_corr, fig_corr = plot_umap(
#     X_cosine,
#     clusters_cosine,
#     n_neighbors=25,
#     min_dist=0.1,
#     metric='cosine',      # for Pipeline 2, use 'euclidean'
#     transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
#     n_pcs=None,              # set if X are raw features; None if X are already PCs
#     standardize=False,       # True if using raw features without PCA
#     seed=42,
#     palette=hex_colors,            # optional list/array or matplotlib colormap name
#     title="Cosine (improved)"
# )

# multi_sil_cosine= silhouettes_multi(X_cosine, clusters_cosine, metric_main='cosine', extra_metrics=['euclidean', 'correlation', 'jaccard'])

# print(f"Silhouette (Cosine): {metrics_cosine['silhouette']:.3f}")
# print(f"Silhouette (Cosine) - Euclidean: {multi_sil_cosine['silhouette_cosine']:.3f}")
# print(f"Silhouette (Cosine) - Correlation: {multi_sil_cosine['silhouette_correlation']:.3f}")
# print(f"Silhouette (Cosine) - Jaccard: {multi_sil_cosine['silhouette_jaccard']:.3f}")


# read id -> (cluster, colour), as a dict and as a table joined onto df_proc.
read_dict_corr = dict_id_cluster_color(df_corr, clusters_corr, hex_colors)
df_dict_corr = dict_to_df(read_dict_corr)
df_test_corr = merge(df_proc, df_dict_corr)

# Long format, then one panel per untreated group.
df_test_plot_corr = preprocess_long_for_plot(df_test_corr, include_locus_cluster=True)
plot_reads_long(
    df_test_plot_corr,
    hex_colors=hex_colors,
    color_by="locus_cluster",
    facet_by='group',
    filters={'group': [
        'BT474_mV_high_Untreated',
        'BT474_mV_low_Untreated',
        'BT474_mV_Untreated_Unsorted',
    ]},
)

# Same three groups, selected on the 'group' index level, merged with the mapping.
df_only_interesting_groups_corr = merge(
    df_corr[df_corr.index.get_level_values('group').isin([
        'BT474_mV_high_Untreated',
        'BT474_mV_low_Untreated',
        'BT474_mV_Untreated_Unsorted',
    ])],
    df_dict_corr,
)
clusters_of_interest_corr = df_only_interesting_groups_corr['locus_cluster'].tolist()
summary(df_only_interesting_groups_corr, clusters_of_interest_corr)

start_corr, end_corr, center_coord_corr = start_end_center(df_test_plot_corr)

plot_avg_methylation_profile(
    df=df_corr,
    df_dict=df_dict_corr,
    read_dict=read_dict_corr,
    start=start_corr,
    end=end_corr,
    center_coord=center_coord_corr,
)

In [None]:
# Run the PCA + Euclidean pipeline (arcsine transform, Laplacian kernel) for every
# gene. gene_names[i] labels gene_df[i]; for each gene this prints metrics and
# draws the UMAP, read plot, cluster summary and per-cluster centroid profiles.
for gene, df_gene in zip(gene_names, gene_df):
    print(gene)

    df_PCA_euclidean, X_PCA_euclidean, partition_PCA_euclidean, clusters_PCA_euclidean, metrics_PCA_euclidean = clustering_final(
        df_gene,
        n_neighbors=15,
        nan_threshold=0.7,
        nan_method='drop',
        scaling=False,
        pca_or_not=True,
        n_pcs=None,           # None = no PCA; int for #PCs; float in (0,1) for variance ratio
        metric='euclidean',          # 'euclidean' or 'cosine'
        transform='arcsine',            # 'none', 'logit', or 'arcsine'
        kernel_type='laplacian',
        leiden_resolution=0.8,
        seed=42
    )

    embedding_PCA_euclidean, fig_PCA_euclidean = plot_umap(
        X_PCA_euclidean,
        clusters_PCA_euclidean,
        n_neighbors=25,
        min_dist=0.1,
        metric='euclidean',
        transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
        n_pcs=None,              # set if X are raw features; None if X are already PCs
        standardize=False,       # True if using raw features without PCA
        seed=42,
        palette=hex_colors,            # optional list/array or matplotlib colormap name
        # Title matches the transform actually used above ('arcsine', not 'logit').
        title="PCA + Euclidean + arcsine + Laplacian",
        gene=gene
    )

    print(metrics_PCA_euclidean)

    # Silhouette of the same labels under several distance metrics.
    multi_sil_PCA_euclidean = silhouettes_multi(X_PCA_euclidean, clusters_PCA_euclidean, metric_main='euclidean', extra_metrics=['cosine', 'correlation', 'jaccard'])

    print(f"Silhouette (PCA + Euclidean): {metrics_PCA_euclidean['silhouette']:.3f}")
    print(f"Silhouette (PCA + Euclidean) - Cosine: {multi_sil_PCA_euclidean['silhouette_cosine']:.3f}")
    print(f"Silhouette (PCA + Euclidean) - Correlation: {multi_sil_PCA_euclidean['silhouette_correlation']:.3f}")
    print(f"Silhouette (PCA + Euclidean) - Jaccard: {multi_sil_PCA_euclidean['silhouette_jaccard']:.3f}")


    read_dict_PCA_euclidean = dict_id_cluster_color(df_PCA_euclidean, clusters_PCA_euclidean, hex_colors)

    df_dict_PCA_euclidean = dict_to_df(read_dict_PCA_euclidean)
    df_test_PCA_euclidean = merge(df_proc, df_dict_PCA_euclidean)

    # Long format for plotting
    df_test_plot_PCA_euclidean = preprocess_long_for_plot(df_test_PCA_euclidean,
                                                          include_locus_cluster=True,
                                                          filter_outliers=True,
                                                          max_span_bp=None,     # filtering is based on span_quantile instead
                                                          span_quantile=0.99,
                                                          require_center_inside=True,
                                                          min_cpg=1,
                                                          cpg_window_bp=5000
                                                          )

    # Plot
    plot_reads_long(
        df_test_plot_PCA_euclidean,
        filters={
            # "gene_tss": gene,
            'group': ['BT474_mV_high_Untreated',
                      'BT474_mV_low_Untreated',
                      'BT474_mV_Untreated_Unsorted']},
        # facet_by='group',
        color_by="locus_cluster",
        gene=gene,
        hex_colors=hex_colors
    )

    df_only_interesting_groups_PCA_euclidean = df_PCA_euclidean[df_PCA_euclidean.index.get_level_values('group').isin([
        'BT474_mV_high_Untreated',
        'BT474_mV_low_Untreated',
        'BT474_mV_Untreated_Unsorted'
    ])]

    df_only_interesting_groups_PCA_euclidean = merge(df_only_interesting_groups_PCA_euclidean, df_dict_PCA_euclidean)
    clusters_of_interest_PCA_euclidean = df_only_interesting_groups_PCA_euclidean['locus_cluster'].tolist()
    summary(df_only_interesting_groups_PCA_euclidean, clusters_of_interest_PCA_euclidean, gene=gene)

    profiles_df_PCA_euclidean, coverage_df_PCA_euclidean, meta_df_PCA_euclidean, positions_PCA_euclidean = compute_gene_centroids(
        df_PCA_euclidean,
        df_dict_PCA_euclidean,
        gene_tss=gene,
    )

    plot_centroids_with_shading(
        profiles_df_PCA_euclidean,
        positions_PCA_euclidean,
        meta_df_PCA_euclidean,
        coverage_df_PCA_euclidean,
        hex_colors,
        smooth_sigma=0,
        title=gene + ':Average DNA Methylation Profiles per Cluster - EUCLIDEAN',
        missingness_threshold=0.5,
    )


    # start_PCA_euclidean, end_PCA_euclidean, center_coord_PCA_euclidean= start_end_center(df_test_plot_PCA_euclidean)

    # plot_avg_methylation_profile(
    #     df= df_PCA_euclidean,
    #     df_dict=df_dict_PCA_euclidean,
    #     start=start_PCA_euclidean,
    #     end=end_PCA_euclidean,
    #     center_coord=center_coord_PCA_euclidean,
    #     read_dict=read_dict_PCA_euclidean,
    #     gene= gene
    # )

In [None]:
# Run the cosine pipeline (no PCA, arcsine transform, Laplacian kernel) for every
# gene. gene_names[i] labels gene_df[i]; outputs mirror the Euclidean loop so the
# two pipelines can be compared gene by gene.
for gene, df_gene in zip(gene_names, gene_df):
    print(gene)
    df_cosine, X_cosine, partition_cosine, clusters_cosine, metrics_cosine = clustering_final(
        df_gene,
        n_neighbors=15,
        nan_threshold=0.7,
        nan_method='drop',
        scaling=False,
        pca_or_not=False,
        n_pcs=None,           # None = no PCA; int for #PCs; float in (0,1) for variance ratio
        metric='cosine',          # 'euclidean' or 'cosine'
        transform='arcsine',            # 'none', 'logit', or 'arcsine'
        kernel_type='laplacian',
        leiden_resolution=0.8,
        seed=42
    )

    embedding_cosine, fig_cosine = plot_umap(
        X_cosine,
        clusters_cosine,
        n_neighbors=25,
        min_dist=0.1,
        metric='cosine',         # same metric as the clustering above
        transform=None,          # None | 'logit' | 'arcsine' (use only if X are raw proportions)
        n_pcs=None,              # set if X are raw features; None if X are already PCs
        standardize=False,       # True if using raw features without PCA
        seed=42,
        palette=hex_colors,            # optional list/array or matplotlib colormap name
        title="NO PCA + cosine + laplacian",
        gene=gene
    )

    print(metrics_cosine)

    # Silhouette of the same labels under several distance metrics.
    multi_sil_cosine = silhouettes_multi(X_cosine, clusters_cosine, metric_main='cosine', extra_metrics=['euclidean', 'correlation', 'jaccard'])

    print(f"Silhouette (Cosine): {metrics_cosine['silhouette']:.3f}")
    print(f"Silhouette (Cosine) - Euclidean: {multi_sil_cosine['silhouette_euclidean']:.3f}")
    print(f"Silhouette (Cosine) - Correlation: {multi_sil_cosine['silhouette_correlation']:.3f}")
    print(f"Silhouette (Cosine) - Jaccard: {multi_sil_cosine['silhouette_jaccard']:.3f}")


    read_dict_cosine = dict_id_cluster_color(df_cosine, clusters_cosine, hex_colors)

    df_dict_cosine = dict_to_df(read_dict_cosine)
    df_test_cosine = merge(df_proc, df_dict_cosine)

    # Long format for plotting
    df_test_plot_cosine = preprocess_long_for_plot(df_test_cosine,
                                                   include_locus_cluster=True,
                                                   filter_outliers=True,
                                                   max_span_bp=None,     # filtering is based on span_quantile instead
                                                   span_quantile=0.99,
                                                   require_center_inside=True,
                                                   min_cpg=1,
                                                   cpg_window_bp=5000
                                                   )

    # Plot
    plot_reads_long(
        df_test_plot_cosine,
        filters={
            # "gene_tss": gene,
            'group': ['BT474_mV_high_Untreated',
                      'BT474_mV_low_Untreated',
                      'BT474_mV_Untreated_Unsorted']},
        # facet_by='group',
        color_by="locus_cluster",
        gene=gene,
        hex_colors=hex_colors
    )

    df_only_interesting_groups_cosine = df_cosine[df_cosine.index.get_level_values('group').isin([
        'BT474_mV_high_Untreated',
        'BT474_mV_low_Untreated',
        'BT474_mV_Untreated_Unsorted'
    ])]

    df_only_interesting_groups_cosine = merge(df_only_interesting_groups_cosine, df_dict_cosine)
    clusters_of_interest_cosine = df_only_interesting_groups_cosine['locus_cluster'].tolist()
    # Pass the gene like the Euclidean loop does, so each per-gene summary is labelled.
    summary(df_only_interesting_groups_cosine, clusters_of_interest_cosine, gene=gene)

    profiles_df_cosine, coverage_df_cosine, meta_df_cosine, positions_cosine = compute_gene_centroids(
        df_cosine,
        df_dict_cosine,
        gene_tss=gene,
    )

    plot_centroids_with_shading(
        profiles_df_cosine,
        positions_cosine,
        meta_df_cosine,
        coverage_df_cosine,
        hex_colors,
        smooth_sigma=0,
        title=gene + ' : Average DNA Methylation Profiles per Cluster - COSINE ',
        missingness_threshold=0.5,
    )

    # start_cosine, end_cosine, center_coord_cosine= start_end_center(df_test_plot_cosine)

    # plot_avg_methylation_profile(
    #     df= df_cosine,
    #     df_dict=df_dict_cosine,
    #     start=start_cosine,
    #     end=end_cosine,
    #     center_coord=center_coord_cosine,
    #     read_dict=read_dict_cosine,
    #     gene=gene
    # )

In [None]:
# Single-gene cosine run (no PCA, arcsine transform, Laplacian kernel).
# NOTE(review): if the cells run top to bottom, `gene` / `df_gene` here still hold
# the LAST gene of the per-gene loops above — confirm that is the intended gene.
df_cosine, X_cosine, partition_cosine, clusters_cosine, metrics_cosine = clustering_final(
    df_gene,
    seed=42,
    # graph + community detection
    n_neighbors=15,
    metric='cosine',
    kernel_type='laplacian',
    leiden_resolution=0.8,
    # preprocessing
    nan_threshold=0.7,
    nan_method='drop',
    transform='arcsine',
    scaling=False,
    pca_or_not=False,
    n_pcs=None,
)

embedding_cosine, fig_cosine = plot_umap(
    X_cosine,
    clusters_cosine,
    gene=gene,
    title="NO PCA + cosine + laplacian",
    palette=hex_colors,
    seed=42,
    metric='cosine',
    n_neighbors=25,
    min_dist=0.1,
    # X_cosine is already the transformed feature matrix: no extra transform,
    # PCA or standardisation inside the UMAP helper.
    transform=None,
    n_pcs=None,
    standardize=False,
)

print(metrics_cosine)

# Silhouette of the cosine labels re-scored under other distance metrics.
multi_sil_cosine = silhouettes_multi(
    X_cosine,
    clusters_cosine,
    metric_main='cosine',
    extra_metrics=['euclidean', 'correlation', 'jaccard'],
)

print(f"Silhouette (Cosine): {metrics_cosine['silhouette']:.3f}")
print(f"Silhouette (Cosine) - Euclidean: {multi_sil_cosine['silhouette_euclidean']:.3f}")
print(f"Silhouette (Cosine) - Correlation: {multi_sil_cosine['silhouette_correlation']:.3f}")
print(f"Silhouette (Cosine) - Jaccard: {multi_sil_cosine['silhouette_jaccard']:.3f}")

# read id -> (cluster, colour) mapping, as a dict, as a table, and joined to df_proc.
read_dict_cosine = dict_id_cluster_color(df_cosine, clusters_cosine, hex_colors)
df_dict_cosine = dict_to_df(read_dict_cosine)
df_test_cosine = merge(df_proc, df_dict_cosine)


In [None]:
# Preview: first five rows of the current gene's read matrix.
df_gene.head(n=5)

In [None]:
# Preview: first five rows of the frame returned by the cosine clustering.
df_cosine.head(n=5)

In [None]:
# Preview: first five rows of the read id -> cluster/colour table.
df_dict_cosine.head(n=5)

In [None]:
# Per-cluster centroids on the UNBINNED read matrix, using the cluster
# assignment from the cosine run (df_dict_cosine).
profiles_df_u, coverage_df_u, meta_df_u, positions_u = compute_gene_centroids(
    df_reads=df_gene_unbinned,
    df_map=df_dict_cosine,
    gene_tss='SETD1A_30958295_30958295',
    # non-bin columns to carry into the meta table (extend if more are present)
    extra_meta_cols=['gene_tss', 'group', 'cluster'],
)

plot_centroids_with_shading(
    profiles_df_u,
    positions_u,
    meta_df_u,
    coverage_df_u,
    hex_colors,
    missingness_threshold=0.5,
    smooth_sigma=0,
    title='Average DNA Methylation Profiles per Cluster',
)
meta_df_u

In [None]:
# Per-cluster centroids on the frame the cosine clustering returned (df_cosine),
# with the matching read id -> cluster mapping.
profiles_df, coverage_df, meta_df, positions = compute_gene_centroids(
    df_reads=df_cosine,
    df_map=df_dict_cosine,
    gene_tss='SETD1A_30958295_30958295',
    # non-bin columns to carry into the meta table (extend if more are present)
    extra_meta_cols=['gene_tss', 'group', 'cluster'],
)

plot_centroids_with_shading(
    profiles_df,
    positions,
    meta_df,
    coverage_df,
    hex_colors,
    missingness_threshold=0.5,
    smooth_sigma=0,
    title='Average DNA Methylation Profiles per Cluster',
)
meta_df

In [None]:
# Preview: first five per-cluster centroid profiles.
profiles_df.head(n=5)

In [None]:
# Bulk centroids over the whole dataset. df_bulk is expected to hold the columns
# ['cluster', 'C_pos', 'meth', 'cov'] (not checked here).
P_df, coverage_bulk_df, meta_bulk_df, positions_bulk = compute_bulk_centroids(
    df_bulk,
    use_cov_for_proportion=True,
)

fig, axes = plot_centroids_with_shading(
    P_df=P_df,
    meta_df=meta_bulk_df,
    positions=positions_bulk,
    title="Average DNA Methylation Profiles per Cluster (Bulk)",
    hex_colors=hex_colors,
)

P_df.head(n=15)

In [None]:
# Display the meta table returned by compute_bulk_centroids in the previous cell.
meta_bulk_df

In [None]:
# Plot the binned bulk centroids. profiles_df_bulk_binned, positions_bulk_binned
# and meta_df_bulk_binned are defined elsewhere in the notebook.
fig, axes = plot_centroids_with_shading(
    P_df=profiles_df_bulk_binned,
    positions=positions_bulk_binned,
    # careful: this is NOT the meta returned by compute_gene_centroids
    meta_df=meta_df_bulk_binned,
    hex_colors=hex_colors,
    title="Average DNA Methylation Profiles per Cluster (Bulk) (Binned)",
)

profiles_df_bulk_binned.head(n=15)

In [None]:
# Bulk (binned) prototype profiles as a float matrix.
P_array = profiles_df_bulk_binned.to_numpy(float)

# Optional per-bin weights for cosine: total coverage of each coverage_df column,
# normalised by the mean total (the 1e-9 keeps the denominator non-zero).
# NOTE(review): coverage_df comes from the gene-level compute_gene_centroids cell;
# confirm its columns line up with the bulk positions before using these weights.
coverage_sum_per_bin = coverage_df.sum(axis=0)
weights = coverage_sum_per_bin.to_numpy(float) / (coverage_sum_per_bin.mean() + 1e-9)


In [None]:
# Match each gene-level cluster profile (rows of profiles_df) to a bulk binned
# prototype (rows of profiles_df_bulk_binned) by cosine similarity on common positions.
assign_df, S, D = assign_gene_clusters_to_consensus(
    G=profiles_df,
    P=profiles_df_bulk_binned,
    meta=meta_df,  # aligned with the rows of G
    method='cosine',
    center_rows=True,
    min_overlap=40,
    min_similarity=0.6,
    capacity_per_type=1,
    allow_unassigned=True,
    # weights=weights,  # optional pd.Series indexed by common positions
)

assign_df

In [None]:
# # Build G and meta for assignment
# G = profiles_df_unbinned

# Meta table for the assignment/heatmap helpers: gene id, cluster id, and the
# read count exposed under the name 'size'.
meta_for_assign = (
    meta_df.loc[:, ['gene_tss', 'cluster_id', 'n_reads']]
    .rename(columns={'n_reads': 'size'})
)

# # Optional cosine weights from coverage
# coverage_per_bin = coverage_df.sum(axis=0).to_numpy(float)
# weights = coverage_per_bin / (coverage_per_bin.mean() + 1e-9)

# print(G.shape) # (n_gene_clusters, n_bins_G)
# print(P_array.shape) # (n_prototypes, n_bins_P)
# print(weights.shape) # (n_bins_weights)


In [None]:
# # Assign to consensus prototypes

# assign_df, S, D = assign_gene_clusters_to_consensus(
#     G=G,                   # rows = clusters from some gene, columns must match positions
#     meta=meta_for_assign,    # DataFrame with ['gene_id','cluster_id','size']
#     P=P_df,
#     method='cosine',       # or 'pearson'
#     center_rows=False,
#     weights=weights,
#     min_similarity=0.4,
#     allow_unassigned=True,
#     capacity_per_type=1
# )


In [None]:

# Locus whose clusters were assigned to the bulk prototypes (the assignment above
# uses meta_df computed for this gene_tss). Used for both the heatmap and the
# recolouring so the two always refer to the same gene, rather than whatever
# `gene` was left over from the per-gene loops.
GENE_ID_OF_INTEREST = 'SETD1A_30958295_30958295'

# 2) Heatmap for one gene
fig, ax = plot_gene_mapping_heatmap(
    assign_df=assign_df,
    S=S,
    meta=meta_for_assign,
    P=profiles_df_bulk_binned,
    gene_id=GENE_ID_OF_INTEREST,
    sort_rows=True
)

# 3) Assignment summary across genes
fig2, ax2 = plot_assignment_summary(assign_df)

# 4) Recolor meta for plotting centroids.
# Key the colours by the prototypes the assignment actually used
# (profiles_df_bulk_binned, not P_df), and use real colours instead of the
# '#hexcolor' placeholder, which is not a valid colour string.
# NOTE(review): assumes the bulk plot colours prototypes by their position in
# hex_colors — switch to meta_df_bulk_binned's colour column if it has one.
bulk_color_map = {
    cons_id: hex_colors[k % len(hex_colors)]
    for k, cons_id in enumerate(profiles_df_bulk_binned.index)
}
meta_colored = recolor_meta_by_bulk(assign_df, meta_df_u, bulk_color_map, gene_id=GENE_ID_OF_INTEREST)

# 5) Assignment confidence
# fig3, ax3 = plot_assignment_confidence(assign_df, gene_id=GENE_ID_OF_INTEREST)


In [None]:
# Pipeline configurations compared across all genes.
configs = [
    # --- PCA + Euclidean (no scaling) ---
    dict(
        name='pca_euclidean',
        n_neighbors=15,
        nan_threshold=0.7,
        nan_method='drop',
        scaling=False,
        pca_or_not=True,
        # NOTE(review): an earlier note said None means "keep 95% variance"; what
        # None selects is decided inside clustering_final — confirm.
        n_pcs=None,
        metric='euclidean',
        transform='logit',  # can also use 'arcsine' or 'none'
        kernel_type='laplacian',
        leiden_resolution=1.0,
        seed=42
    ),

    # --- No PCA + Cosine (shape-based) ---
    dict(
        name='cosine_no_pca',
        n_neighbors=15,
        nan_threshold=0.7,
        nan_method='drop',
        scaling=False,  # keep raw 0..1; cosine similarity handles scale
        pca_or_not=False,
        n_pcs=None,
        metric='cosine',
        # NOTE(review): differs from 'logit' in pca_euclidean, so this comparison
        # is NOT transform-matched; align the two if a fair comparison is intended.
        transform='none',
        kernel_type='laplacian',
        leiden_resolution=1.0,
        seed=42
    )
]

# gene_names and gene_df are parallel sequences: gene_names[i] labels gene_df[i].
results_df, outputs = run_pipelines_on_genes(gene_names, gene_df, configs)

print(results_df['pipeline'].value_counts())

# Sanity check: for each metric, does any pipeline have only missing values?
for m in ['silhouette', 'calinski_harabasz', 'davies_bouldin', 'leiden_quality', 'modularity', 'n_clusters']:
    if m in results_df.columns:
        print(m, results_df.pivot_table(index='gene', columns='pipeline', values=m, aggfunc='first').isna().all())
# Plot violin + scatter comparison of clustering metrics
# figs = plot_violin_scatter(results_df)

fig, axes = plot_compare_pipelines_grid(results_df,
                                        pipelines_order=('pca_euclidean', 'cosine_no_pca'),
                                        metrics=('silhouette', 'calinski_harabasz', 'davies_bouldin', 'leiden_quality', 'modularity', 'n_clusters'),
                                        kind='violin',  # or 'box'
                                        show_points=True,
                                        connect_pairs=True)

# Last expression of the cell, so Jupyter actually renders the table
# (a mid-cell expression is evaluated and discarded).
results_df.head(20)

In [None]:
# Leiden clustering of all filtered genes: PCA + Euclidean distance on arcsine-transformed
# methylation values, Laplacian kernel.
# NOTE(review): this rebinds `df_filtered` (the input) to clustering_final's output, so
# re-running the cell — or running the cosine cell, which also reads `df_filtered` —
# operates on the already-processed frame rather than the original one.
df_filtered, X_filtered, partition_filtered, clusters_filtered, metrics_filtered = clustering_final(
    df_filtered,
    n_neighbors=15,
    nan_threshold=0.7,
    nan_method='drop',
    scaling=False,
    pca_or_not=True,
    n_pcs=None,               # int = #PCs, float in (0,1) = variance ratio
                              # NOTE(review): an older note said "None = no PCA", which
                              # conflicts with pca_or_not=True — confirm in clustering_final
    metric='euclidean',       # 'euclidean' or 'cosine'
    transform='arcsine',      # 'none', 'logit', or 'arcsine'
    kernel_type='laplacian',
    leiden_resolution=1.1,
    seed=42,
)

# UMAP of the clustering features. X_filtered comes out of clustering_final, so it is
# presumably already transformed / reduced: no extra transform, PCA or scaling here.
embedding_filtered, fig_filtered = plot_umap(
    X_filtered,
    clusters_filtered,
    n_neighbors=25,
    min_dist=0.1,
    metric='euclidean',
    transform=None,           # None | 'logit' | 'arcsine' (only for raw proportions)
    n_pcs=None,               # None when X already holds PCs
    standardize=False,        # True only for raw features without PCA
    seed=42,
    palette=hex_colors,       # list/array or matplotlib colormap name
    # Title must match the transform used above ('arcsine', not 'logit').
    title="ALL FILTERED GENES PCA + Euclidean + arcsine + Laplacian",
    gene='ALL GENES',
)

print(metrics_filtered)

# Silhouette of the same partition scored under several distance metrics.
# NOTE(review): 'jaccard' is a boolean-set metric; if silhouettes_multi relies on sklearn,
# continuous coordinates are cast to bool first, so that score is likely not meaningful.
multi_sil_filtered = silhouettes_multi(
    X_filtered, clusters_filtered,
    metric_main='euclidean',
    extra_metrics=['cosine', 'correlation', 'jaccard'],
)

print(f"Silhouette (PCA + Euclidean): {metrics_filtered['silhouette']:.3f}")
for sil_label in ('Cosine', 'Correlation', 'Jaccard'):
    sil_key = 'silhouette_' + sil_label.lower()
    print(f"Silhouette (PCA + Euclidean) - {sil_label}: {multi_sil_filtered[sil_key]:.3f}")


# Build the per-read cluster/colour mapping (presumably read id -> cluster, colour)
# and attach it to the processed reads.
read_dict_filtered = dict_id_cluster_color(df_filtered, clusters_filtered, hex_colors)
df_dict_filtered = dict_to_df(read_dict_filtered)
df_test_filtered = merge(df_proc, df_dict_filtered)

# Groups shown in the read plot and the cluster summary below. The list is kept in one
# place so the plot filter and the summary subset cannot drift apart.
untreated_groups = [
    'BT474_mV_high_Untreated',
    'BT474_mV_low_Untreated',
    'BT474_mV_Untreated_Unsorted',
]

# Long format for plotting. max_span_bp=None: outlier filtering relies on the
# 0.99 span quantile instead of a fixed bp cut-off.
df_test_plot_filtered = preprocess_long_for_plot(
    df_test_filtered,
    include_locus_cluster=True,
    filter_outliers=True,
    max_span_bp=None,
    span_quantile=0.99,
    require_center_inside=True,
    min_cpg=1,
    cpg_window_bp=5000,
)

# Reads coloured by cluster (optional: add "gene_tss" to filters, or facet_by='group').
plot_reads_long(
    df_test_plot_filtered,
    filters={'group': untreated_groups},
    color_by="locus_cluster",
    gene='all genes',
    hex_colors=hex_colors,
)

# Cluster summary restricted to the same groups.
df_only_interesting_groups_filtered = df_filtered[
    df_filtered.index.get_level_values('group').isin(untreated_groups)
]
df_only_interesting_groups_filtered = merge(df_only_interesting_groups_filtered, df_dict_filtered)
clusters_of_interest_filtered = df_only_interesting_groups_filtered['locus_cluster'].tolist()
summary(df_only_interesting_groups_filtered, clusters_of_interest_filtered, gene='all genes')


# Per-cluster centroid methylation profiles for this cell's all-genes Euclidean clustering.
# Both the computation and the plot use this cell's outputs (df_filtered /
# df_dict_filtered and the *_filtered results), not the *_PCA_euclidean variables.
# NOTE(review): `gene` is not set in this cell; it is whatever an earlier cell left behind
# and may restrict the centroids to a single TSS — confirm what compute_gene_centroids
# expects for an all-genes run.
profiles_df_filtered, coverage_df_filtered, meta_df_filtered, positions_filtered = compute_gene_centroids(
    df_filtered,
    df_dict_filtered,
    gene_tss=gene,
)

plot_centroids_with_shading(
    profiles_df_filtered,
    positions_filtered,
    meta_df_filtered,
    coverage_df_filtered,
    hex_colors,
    smooth_sigma=0,
    title='Average DNA Methylation Profiles per Cluster',
    missingness_threshold=0.5,
)

# start_filtered, end_filtered, center_coord_filtered= start_end_center(df_test_plot_filtered)

# plot_avg_methylation_profile(
#         df= df_filtered,
#         df_dict=df_dict_filtered,
#         start=start_filtered,
#         end=end_filtered,
#         center_coord=center_coord_filtered,
#         read_dict=read_dict_filtered,
#         gene= 'all genes'
#     )

In [None]:
# Leiden clustering of all filtered genes with cosine distance (no transform, Laplacian kernel).
# NOTE(review): the Euclidean cell rebinds `df_filtered` to its clustered output, so this
# cell's input depends on whether (and how often) that cell was run first.
# NOTE(review): pca_or_not=True here, while the labels below say only "Cosine" and the
# 'cosine_no_pca' pipeline config uses pca_or_not=False — confirm which was intended.
df_filtered_c, X_filtered_c, partition_filtered_c, clusters_filtered_c, metrics_filtered_c = clustering_final(
    df_filtered,
    metric='cosine',
    transform='none',
    kernel_type='laplacian',
    pca_or_not=True,
    n_pcs=None,
    scaling=False,
    n_neighbors=15,
    nan_threshold=0.7,
    nan_method='drop',
    leiden_resolution=1.1,
    seed=42,
)

# UMAP on clustering_final's features with the same (cosine) metric; no further
# transform / PCA / standardisation is applied on top.
embedding_filtered_c, fig_filtered_c = plot_umap(
    X_filtered_c,
    clusters_filtered_c,
    metric='cosine',
    n_neighbors=25,
    min_dist=0.1,
    transform=None,
    n_pcs=None,
    standardize=False,
    palette=hex_colors,
    seed=42,
    gene='ALL GENES',
    title="ALL FILTERED GENES Cosine",
)

print(metrics_filtered_c)

# Silhouette of the cosine partition re-scored under other distance metrics.
multi_sil_filtered_c = silhouettes_multi(
    X_filtered_c, clusters_filtered_c,
    metric_main='cosine',
    extra_metrics=['euclidean', 'correlation', 'jaccard'],
)

print(f"Silhouette (Cosine): {metrics_filtered_c['silhouette']:.3f}")
for sil_label_c in ('Euclidean', 'Correlation', 'Jaccard'):
    sil_key_c = 'silhouette_' + sil_label_c.lower()
    print(f"Silhouette (Cosine) - {sil_label_c}: {multi_sil_filtered_c[sil_key_c]:.3f}")


# Build the per-read cluster/colour mapping for the cosine clustering
# (presumably read id -> cluster, colour) and attach it to the processed reads.
read_dict_filtered_c = dict_id_cluster_color(df_filtered_c, clusters_filtered_c, hex_colors)
df_dict_filtered_c = dict_to_df(read_dict_filtered_c)
df_test_filtered_c = merge(df_proc, df_dict_filtered_c)

# Groups shown in the read plot and the cluster summary below. The list is kept in one
# place so the plot filter and the summary subset cannot drift apart.
untreated_groups = [
    'BT474_mV_high_Untreated',
    'BT474_mV_low_Untreated',
    'BT474_mV_Untreated_Unsorted',
]

# Long format for plotting. max_span_bp=None: outlier filtering relies on the
# 0.99 span quantile instead of a fixed bp cut-off.
df_test_plot_filtered_c = preprocess_long_for_plot(
    df_test_filtered_c,
    include_locus_cluster=True,
    filter_outliers=True,
    max_span_bp=None,
    span_quantile=0.99,
    require_center_inside=True,
    min_cpg=1,
    cpg_window_bp=5000,
)

# Reads coloured by cluster (optional: add "gene_tss" to filters, or facet_by='group').
plot_reads_long(
    df_test_plot_filtered_c,
    filters={'group': untreated_groups},
    color_by="locus_cluster",
    gene='all genes - cosine',
    hex_colors=hex_colors,
)

# Cluster summary restricted to the same groups.
df_only_interesting_groups_filtered_c = df_filtered_c[
    df_filtered_c.index.get_level_values('group').isin(untreated_groups)
]
df_only_interesting_groups_filtered_c = merge(df_only_interesting_groups_filtered_c, df_dict_filtered_c)
clusters_of_interest_filtered_c = df_only_interesting_groups_filtered_c['locus_cluster'].tolist()
summary(df_only_interesting_groups_filtered_c, clusters_of_interest_filtered_c, gene='all genes - cosine')


# Cluster-averaged methylation profile for the cosine clustering. The window
# (start, end, centre) is derived from the long-format plotting frame.
start_filtered_c, end_filtered_c, center_coord_filtered_c = start_end_center(
    df_test_plot_filtered_c
)

plot_avg_methylation_profile(
    df=df_filtered_c,
    df_dict=df_dict_filtered_c,
    read_dict=read_dict_filtered_c,
    start=start_filtered_c,
    end=end_filtered_c,
    center_coord=center_coord_filtered_c,
    gene='all genes - cosine',
)